diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml
index 79b15d9..b171e27 100644
--- a/.github/workflows/test-lint.yml
+++ b/.github/workflows/test-lint.yml
@@ -26,6 +26,11 @@ jobs:
python-version: '3.12'
cache: 'pip'
+ - name: Install system dependencies (OpenGL for MuJoCo)
+ run: |
+ sudo apt-get update
+ sudo apt-get install -y libosmesa6-dev
+
- name: Install dependencies
run: |
pip install --no-cache-dir hatch
@@ -35,4 +40,6 @@ jobs:
run: hatch run lint
- name: Run tests
+ env:
+ MUJOCO_GL: osmesa
run: hatch run test -x --strict-markers
diff --git a/README.md b/README.md
index 0a93a93..a299d01 100644
--- a/README.md
+++ b/README.md
@@ -486,30 +486,22 @@ while True:
agent.tool.gr00t_inference(action="stop", port=8000)
```
-## Configuration
-
-### Environment Variables
+## Environment Variables
| Variable | Description | Default |
|----------|-------------|---------|
-| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (MJCF, meshes) | `~/.strands_robots/assets/` |
-| `GROOT_API_TOKEN` | API token for GR00T inference service | — |
-
-### Cache Directory
+| `STRANDS_ROBOT_MODE` | Override auto-detection mode (`sim` or `real`) | (auto-detect) |
+| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (URDF/MJCF/meshes) | `~/.strands_robots/assets/` |
+| `STRANDS_URDF_DIR` | **Deprecated** — use `STRANDS_ASSETS_DIR` instead | — |
+| `STRANDS_TRUST_REMOTE_CODE` | Set to `1` to allow loading remote HuggingFace policies | (disabled) |
+| `GROOT_API_TOKEN` | API token for NVIDIA GR00T cloud inference | — |
+| `MUJOCO_GL` | OpenGL backend for MuJoCo rendering (`egl`, `osmesa`, `glfw`) | Auto-configured on headless Linux |
-Robot model assets (MJCF XML files and meshes) are cached in:
+**Cache directory:** Robot model assets (URDF, MJCF, meshes) are downloaded on first use to `~/.strands_robots/assets/`. Override with `STRANDS_ASSETS_DIR`. To clear cached assets:
+```bash
+rm -rf ~/.strands_robots/assets/
```
-~/.strands_robots/
-└── assets/ # Downloaded robot models (from robot_descriptions / MuJoCo Menagerie)
- ├── trs_so_arm100/
- ├── franka_emika_panda/
- └── ...
-```
-
-To clear the cache: `rm -rf ~/.strands_robots/assets/`
-
-To change the cache location: `export STRANDS_ASSETS_DIR=/path/to/custom/dir`
## Contributing
diff --git a/pyproject.toml b/pyproject.toml
index f1a7090..d9ae441 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,13 +48,16 @@ groot-service = [
lerobot = [
"lerobot>=0.5.0,<0.6.0",
]
-sim = [
+sim-mujoco = [
"robot_descriptions>=1.11.0,<2.0.0",
+ "mujoco>=3.0.0,<4.0.0",
+ "imageio>=2.28.0,<3.0.0",
+ "imageio-ffmpeg>=0.4.0,<1.0.0",
]
all = [
"strands-robots[groot-service]",
"strands-robots[lerobot]",
- "strands-robots[sim]",
+ "strands-robots[sim-mujoco]",
]
dev = [
"pytest>=6.0,<9.0.0",
@@ -128,7 +131,7 @@ ignore_missing_imports = false
# Third-party libs without type stubs
[[tool.mypy.overrides]]
-module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*"]
+module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*"]
ignore_missing_imports = true
# @tool decorator injects runtime signatures mypy cannot check
@@ -161,6 +164,22 @@ module = ["strands_robots.registry.*"]
warn_return_any = false
disallow_untyped_defs = false
+# MuJoCo simulation — mixins use cooperative self._world patterns
+# attr-defined: Mixins access self._world/self._lock/etc. from Simulation (cooperative pattern)
+# assignment: PEP 484 implicit Optional (= None on typed params)
+# override: Subclass signatures extend base with extra params (orientation, mesh_path)
+# misc: Multiple inheritance method resolution conflicts between mixin + ABC
+[[tool.mypy.overrides]]
+module = ["strands_robots.simulation.mujoco.*"]
+disallow_untyped_defs = false
+warn_return_any = false
+
+# Async utils and dataset recorder — thin wrappers with dynamic types
+[[tool.mypy.overrides]]
+module = ["strands_robots._async_utils", "strands_robots.dataset_recorder"]
+disallow_untyped_defs = false
+warn_return_any = false
+
# Test files — relaxed type checking for mocks, fixtures, and test utilities
[[tool.mypy.overrides]]
module = ["tests.*", "tests_integ.*"]
diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py
index 8ee9c41..40dc6c2 100644
--- a/strands_robots/__init__.py
+++ b/strands_robots/__init__.py
@@ -11,11 +11,12 @@
- Clean separation between robot control and policy inference
- Direct policy injection for maximum flexibility
- Multi-camera support with rich configuration options
+- MuJoCo simulation backend (no GPU required)
Lazy Loading:
- Heavy imports (Robot, tools, Gr00tPolicy) are deferred until first access.
- Heavy imports are deferred so ``import strands_robots`` stays fast when lerobot/torch
- are installed but not yet needed.
+ Heavy imports (Robot, tools, Gr00tPolicy, Simulation) are deferred until
+ first access. Heavy imports are deferred so ``import strands_robots`` stays
+ fast when lerobot/torch/mujoco are installed but not yet needed.
Light-weight symbols (Policy, MockPolicy, create_policy) are available
immediately since they don't pull in torch/lerobot.
@@ -26,7 +27,7 @@
from typing import Any
# ------------------------------------------------------------------
-# Light-weight imports — no torch / lerobot dependency
+# Light-weight imports — no torch / lerobot / mujoco dependency
# ------------------------------------------------------------------
from strands_robots.policies import MockPolicy, Policy, create_policy # noqa: F401
@@ -35,8 +36,21 @@
# ------------------------------------------------------------------
# Maps public name -> (module_path, attribute_name)
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
+ # Hardware robot
"Robot": ("strands_robots.robot", "Robot"),
+ "list_robots": ("strands_robots.registry", "list_robots"),
+ # Policies
"Gr00tPolicy": ("strands_robots.policies.groot", "Gr00tPolicy"),
+ # Simulation (MuJoCo)
+ "Simulation": ("strands_robots.simulation", "Simulation"),
+ "create_simulation": ("strands_robots.simulation.factory", "create_simulation"),
+ "list_backends": ("strands_robots.simulation.factory", "list_backends"),
+ "register_backend": ("strands_robots.simulation.factory", "register_backend"),
+ "SimWorld": ("strands_robots.simulation", "SimWorld"),
+ "SimRobot": ("strands_robots.simulation", "SimRobot"),
+ "SimObject": ("strands_robots.simulation", "SimObject"),
+ "SimCamera": ("strands_robots.simulation", "SimCamera"),
+ # Tools
"gr00t_inference": ("strands_robots.tools.gr00t_inference", "gr00t_inference"),
"lerobot_calibrate": ("strands_robots.tools.lerobot_calibrate", "lerobot_calibrate"),
"lerobot_camera": ("strands_robots.tools.lerobot_camera", "lerobot_camera"),
@@ -53,6 +67,11 @@
# Lazy-loaded
"Robot",
"Gr00tPolicy",
+ "Simulation",
+ "SimWorld",
+ "SimRobot",
+ "SimObject",
+ "SimCamera",
"gr00t_inference",
"lerobot_camera",
"lerobot_teleoperate",
@@ -62,12 +81,32 @@
]
+# Auto-configure MuJoCo GL backend for headless environments BEFORE any
+# module imports mujoco at the top level. MuJoCo locks the OpenGL backend
+# at import time, so MUJOCO_GL must be set first.
+#
+# WHY EAGER: This MUST run at module import time, not lazily, because:
+# 1. MuJoCo reads MUJOCO_GL only on first `import mujoco`
+# 2. Any downstream code doing `from strands_robots.simulation import ...`
+# triggers mujoco import via the lazy-load chain
+# 3. If we defer to first use, the env var would be set too late
+# This is the canonical location — strands_robots/simulation/__init__.py
+# intentionally does NOT duplicate this call.
+try:
+ from strands_robots.simulation.mujoco.backend import _configure_gl_backend
+
+ _configure_gl_backend()
+except (ImportError, AttributeError, OSError):
+ pass
+
+
def __getattr__(name: str) -> Any: # noqa: N807
"""Lazy-load heavy modules on first attribute access.
- This avoids importing torch, lerobot, numpy, pyserial, etc. at
+ This avoids importing torch, lerobot, numpy, mujoco, pyserial, etc. at
``import strands_robots`` time. The first access to e.g.
- ``strands_robots.Robot`` triggers the real import.
+ ``strands_robots.Robot`` or ``strands_robots.Simulation`` triggers the
+ real import.
"""
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py
new file mode 100644
index 0000000..dc9ac8d
--- /dev/null
+++ b/strands_robots/_async_utils.py
@@ -0,0 +1,36 @@
+"""Async-to-sync helper for resolving coroutines in sync contexts."""
+
+import asyncio
+import atexit
+from concurrent.futures import ThreadPoolExecutor
+
+# Module-level executor reused across calls to avoid creating threads at high frequency.
+# A single worker is sufficient — we only need to offload one asyncio.run() at a time.
+_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async")
+
+# Ensure the executor is shut down cleanly on interpreter exit to avoid
+# ResourceWarning and orphaned threads.
+atexit.register(_EXECUTOR.shutdown, wait=False)
+
+
+def _resolve_coroutine(coro_or_result): # type: ignore[no-untyped-def]
+ """Safely resolve a potentially-async result to a sync value.
+
+ Handles three cases:
+ 1. Already a plain value → return as-is
+ 2. Coroutine, no running loop → asyncio.run()
+ 3. Coroutine, inside running loop → offload to reused thread
+
+ Args:
+ coro_or_result: Either a coroutine or an already-resolved value.
+
+ Returns:
+ The resolved (sync) value.
+ """
+ if not asyncio.iscoroutine(coro_or_result):
+ return coro_or_result
+    try:
+        asyncio.get_running_loop()
+    except RuntimeError:
+        return asyncio.run(coro_or_result)
+    return _EXECUTOR.submit(asyncio.run, coro_or_result).result()
diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py
new file mode 100644
index 0000000..39039b7
--- /dev/null
+++ b/strands_robots/dataset_recorder.py
@@ -0,0 +1,516 @@
+"""LeRobotDataset recorder bridge for strands-robots.
+
+Wraps LeRobotDataset so that both strands_robots.hardware_robot and
+strands_robots.simulation (MuJoCo) can produce training-ready datasets with
+a single add_frame() call per control step.
+
+Why top-level (strands_robots.dataset_recorder) and not under simulation/?
+ Both hardware and simulation code paths need to record datasets. Placing
+ this module at the package root avoids a circular dependency
+ (simulation -> dataset_recorder -> simulation) and keeps hardware_robot
+ from reaching into the simulation sub-package.
+
+Usage:
+ recorder = DatasetRecorder.create(
+ repo_id="user/my_dataset",
+ fps=30,
+ robot_features=robot.observation_features,
+ action_features=robot.action_features,
+ task="pick up the red cube",
+ )
+ # In control loop:
+ recorder.add_frame(observation, action, task="pick up the red cube")
+ # End of episode:
+ recorder.save_episode()
+ # Optionally:
+ recorder.push_to_hub()
+"""
+
+import functools
+import logging
+import sys
+from typing import Any
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+# ── Lazy check for LeRobot availability ──────────────────────────────
+# We must NOT import lerobot at module level because it pulls in
+# `datasets` → `pandas`, which can crash with a numpy ABI mismatch on
+# systems where the system pandas was compiled against an older numpy
+# (e.g. JetPack / Jetson with system pandas 2.1.4 + pip numpy 2.x).
+
+
+@functools.lru_cache(maxsize=1)
+def has_lerobot_dataset() -> bool:
+ """Check if lerobot is available. Result is cached after first call."""
+ try:
+ from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: F401
+
+ return True
+ except (ImportError, ValueError, RuntimeError) as exc:
+ logger.debug("lerobot not available: %s", exc)
+ return False
+
+
+def _get_lerobot_dataset_class():
+ """Import and return LeRobotDataset class, or raise ImportError.
+
+ Supports test mocking: if ``strands_robots.dataset_recorder.LeRobotDataset``
+ has been set (by a test mock), returns that class directly.
+ """
+ # Support test mocking: check module-level overrides
+ this_module = sys.modules[__name__]
+
+ # If a test injected a mock LeRobotDataset class, use it
+ mock_cls = getattr(this_module, "LeRobotDataset", None)
+ if mock_cls is not None:
+ return mock_cls
+
+ # Actual import
+ try:
+ from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+ return LeRobotDataset
+ except (ImportError, ValueError, RuntimeError) as exc:
+ raise ImportError(
+ f"lerobot not available ({exc}). Install with: pip install lerobot\nRequired for LeRobotDataset recording."
+ ) from exc
+
+
+class DatasetRecorder:
+ """Bridge between strands-robots control loops and LeRobotDataset.
+
+ Handles the full lifecycle:
+ 1. create() — build LeRobotDataset with correct features
+ 2. add_frame() — called every control step with obs + action
+ 3. save_episode() — finalize episode (encodes video, writes parquet)
+ 4. push_to_hub() — upload to HuggingFace
+
+ Works for both real hardware (robot.py) and simulation (simulation.py).
+ """
+
+ def __init__(self, dataset, task: str = "", strict: bool = True):
+ self.dataset = dataset
+ self.default_task = task
+ self.frame_count = 0
+ self.dropped_frame_count = 0
+ self.strict = strict
+ self.episode_count = 0
+ self._closed = False
+ self._cached_state_keys: list[str] | None = None
+ self._cached_action_keys: list[str] | None = None
+
+ @classmethod
+ def create(
+ cls,
+ repo_id: str,
+ fps: int = 30,
+ robot_type: str = "unknown",
+ robot_features: dict[str, Any] | None = None,
+ action_features: dict[str, Any] | None = None,
+ camera_keys: list[str] | None = None,
+ joint_names: list[str] | None = None,
+ task: str = "",
+ root: str | None = None,
+ use_videos: bool = True,
+ vcodec: str = "libsvtav1",
+ streaming_encoding: bool = True,
+ image_writer_threads: int = 4,
+ video_backend: str = "auto",
+ ) -> "DatasetRecorder":
+ """Create a new DatasetRecorder with auto-detected features.
+
+ Args:
+ repo_id: HuggingFace dataset ID (e.g. "user/my_dataset")
+ fps: Recording frame rate
+ robot_type: Robot type string (e.g. "so100", "panda")
+ robot_features: Dict of observation feature names → types
+ (from robot.observation_features or sim joint names)
+ action_features: Dict of action feature names → types
+ camera_keys: List of camera names (images become video features)
+ joint_names: List of joint names (alternative to robot_features for sim)
+ task: Default task description
+ root: Local directory for dataset storage
+ use_videos: Encode camera frames as video (True) or keep as images
+ vcodec: Video codec (h264, hevc, libsvtav1)
+ streaming_encoding: Stream-encode video during capture
+ image_writer_threads: Threads for writing image frames
+ video_backend: Video backend for encoding ("auto" for HW encoder auto-detect)
+ """
+ # Lazy import — this is where we actually need lerobot
+ LeRobotDatasetCls = _get_lerobot_dataset_class()
+
+ # Build features dict in LeRobot format
+ features = cls._build_features(
+ robot_features=robot_features,
+ action_features=action_features,
+ camera_keys=camera_keys,
+ joint_names=joint_names,
+ use_videos=use_videos,
+ )
+
+ logger.info(f"Creating LeRobotDataset: {repo_id} @ {fps}fps, {len(features)} features, robot_type={robot_type}")
+
+ # Build kwargs, skip unsupported params for this LeRobot version
+ create_kwargs = dict(
+ repo_id=repo_id,
+ fps=fps,
+ root=root,
+ robot_type=robot_type,
+ features=features,
+ use_videos=use_videos,
+ image_writer_threads=image_writer_threads,
+ vcodec=vcodec,
+ )
+ # streaming_encoding only in newer LeRobot versions
+ import inspect
+
+ create_sig = inspect.signature(LeRobotDatasetCls.create)
+ if "streaming_encoding" in create_sig.parameters:
+ create_kwargs["streaming_encoding"] = streaming_encoding
+ if "video_backend" in create_sig.parameters:
+ create_kwargs["video_backend"] = video_backend
+ dataset = LeRobotDatasetCls.create(**create_kwargs)
+
+ recorder = cls(dataset=dataset, task=task)
+ logger.info("DatasetRecorder ready: %s", repo_id)
+ return recorder
+
+ @classmethod
+ def _build_features(
+ cls,
+ robot_features: dict | None = None,
+ action_features: dict | None = None,
+ camera_keys: list[str] | None = None,
+ joint_names: list[str] | None = None,
+ use_videos: bool = True,
+ camera_shapes: dict[str, tuple[int, int, int]] | None = None,
+ ) -> dict[str, Any]:
+ """Build LeRobot v3-compatible features dict.
+
+ LeRobot v3 features format:
+ {
+ "observation.images.camera_name": {"dtype": "video", "shape": (C, H, W), "names": [...]},
+ "observation.state": {"dtype": "float32", "shape": (N,), "names": [...]},
+ "action": {"dtype": "float32", "shape": (N,), "names": [...]},
+ }
+
+ Note: "names" must be a flat list of strings, NOT a dict like {"motors": [...]}.
+ """
+ features = {}
+
+ # --- Observation: cameras → video/image features ---
+ if camera_keys:
+ for cam_name in camera_keys:
+ key = f"observation.images.{cam_name}"
+ dtype = "video" if use_videos else "image"
+ # Per-camera shape override, default (3, 480, 640) CHW
+ shape = (3, 480, 640)
+ if camera_shapes and cam_name in camera_shapes:
+ shape = camera_shapes[cam_name]
+ features[key] = {
+ "dtype": dtype,
+ "shape": shape,
+ "names": ["channels", "height", "width"],
+ }
+
+ # --- Observation: state (joint positions) ---
+ state_dim = 0
+ state_names = []
+ if robot_features:
+ # Count scalar features (exclude cameras)
+ state_keys = [
+ k
+ for k, v in robot_features.items()
+ if not isinstance(v, dict) or v.get("dtype") not in ("image", "video")
+ ]
+ state_dim = len(state_keys)
+ state_names = state_keys
+ elif joint_names:
+ state_dim = len(joint_names)
+ state_names = list(joint_names)
+
+ if state_dim > 0:
+ features["observation.state"] = {
+ "dtype": "float32",
+ "shape": (state_dim,),
+ "names": state_names,
+ }
+
+ # --- Action ---
+ action_dim = 0
+ action_names = []
+ if action_features:
+ action_keys = [
+ k
+ for k, v in action_features.items()
+ if not isinstance(v, dict) or v.get("dtype") not in ("image", "video")
+ ]
+ action_dim = len(action_keys)
+ action_names = action_keys
+ elif joint_names:
+ action_dim = len(joint_names)
+ action_names = list(joint_names)
+ elif state_dim > 0:
+ action_dim = state_dim # Same dim as state by default
+ action_names = state_names[:]
+
+ if action_dim > 0:
+ features["action"] = {
+ "dtype": "float32",
+ "shape": (action_dim,),
+ "names": action_names[:action_dim],
+ }
+
+ return features
+
+ def add_frame(
+ self,
+ observation: dict[str, Any],
+ action: dict[str, Any],
+ task: str | None = None,
+ camera_keys: list[str] | None = None,
+ ) -> None:
+ """Add a single control-loop frame to the dataset.
+
+ This is the key method — called every step in the control loop.
+
+ Args:
+ observation: Raw observation dict from robot/sim
+ (joint_name → float, camera_name → np.ndarray)
+ action: Action dict (joint_name → float)
+ task: Task description (uses default if None)
+ camera_keys: Which keys in observation are camera images
+ """
+ if self._closed:
+ return
+
+ frame = {}
+
+ # --- Detect camera vs state keys ---
+ if camera_keys is None:
+ camera_keys = [k for k, v in observation.items() if isinstance(v, np.ndarray) and v.ndim >= 2]
+
+ state_keys = [k for k in observation.keys() if k not in camera_keys]
+
+ # --- Camera images → observation.images.{name} ---
+ for cam_key in camera_keys:
+ img = observation[cam_key]
+ if isinstance(img, np.ndarray):
+ # LeRobot expects HWC uint8 for add_frame
+ if img.dtype != np.uint8:
+ img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
+ frame[f"observation.images.{cam_key}"] = img
+
+ # --- State → observation.state (flattened vector) ---
+ # Use feature schema ordering to match the dataset schema declared in _build_features().
+ if state_keys:
+ state_vals = []
+ if self._cached_state_keys is None:
+ feat = self.dataset.features.get("observation.state", {})
+ state_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", [])
+ self._cached_state_keys = state_names if state_names else sorted(state_keys)
+
+ for k in self._cached_state_keys:
+ v = observation.get(k)
+ if v is None:
+ state_vals.append(0.0)
+ elif isinstance(v, (int, float)):
+ state_vals.append(float(v))
+ elif isinstance(v, np.ndarray) and v.ndim == 0:
+ state_vals.append(float(v))
+ elif isinstance(v, (list, np.ndarray)):
+ arr = np.asarray(v, dtype=np.float32).flatten()
+ state_vals.extend(arr.tolist())
+ if state_vals:
+ frame["observation.state"] = np.array(state_vals, dtype=np.float32)
+
+ # --- Action → flattened vector ---
+ # Use feature schema ordering for actions too.
+ if action:
+ action_vals = []
+ if self._cached_action_keys is None:
+ feat = self.dataset.features.get("action", {})
+ action_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", [])
+ self._cached_action_keys = action_names if action_names else sorted(action.keys())
+
+ for k in self._cached_action_keys:
+ v = action.get(k)
+ if v is None:
+ action_vals.append(0.0)
+ elif isinstance(v, (int, float)):
+ action_vals.append(float(v))
+ elif isinstance(v, np.ndarray) and v.ndim == 0:
+ action_vals.append(float(v))
+ elif isinstance(v, (list, np.ndarray)):
+ arr = np.asarray(v, dtype=np.float32).flatten()
+ action_vals.extend(arr.tolist())
+ if action_vals:
+ frame["action"] = np.array(action_vals, dtype=np.float32)
+
+ # --- Task (mandatory for LeRobot v3) ---
+ frame["task"] = task or self.default_task or "untitled"
+
+ # --- Reconcile camera keys between frame and feature schema ---
+ # Only strip *undeclared* cameras from the frame (keys present in obs
+ # but not registered in _build_features). This avoids LeRobot's
+ # "Extra features" error. Declared-but-missing cameras (e.g. when a
+ # render fails) are left alone — LeRobot tolerates absent columns and
+ # the episode simply won't have that camera's data.
+ declared_cam_keys = {k for k in self.dataset.features if k.startswith("observation.images.")}
+ frame_cam_keys = {k for k in frame if k.startswith("observation.images.")}
+ for extra in frame_cam_keys - declared_cam_keys:
+ del frame[extra]
+
+ # --- Add to dataset ---
+ try:
+ self.dataset.add_frame(frame)
+ self.frame_count += 1
+ except Exception as e:
+ if self.strict:
+ raise # Fail-fast per AGENTS.md convention #5
+ self.dropped_frame_count += 1
+ n = self.dropped_frame_count
+            # Log at drop counts that are powers of two (1, 2, 4, …) and every 1000th drop
+ if (n & (n - 1)) == 0 or n % 1000 == 0:
+ logger.warning(
+ "add_frame failed (frame %d, dropped %d): %s",
+ self.frame_count,
+ self.dropped_frame_count,
+ e,
+ )
+
+ def save_episode(self) -> dict[str, Any]:
+ """Finalize current episode — writes parquet, encodes video, computes stats.
+
+ LeRobot v3: save_episode() takes no task argument. Tasks are stored
+ per-frame in the episode buffer via add_frame().
+
+ Returns:
+ Dict with episode info
+ """
+ if self._closed:
+ return {"status": "error", "message": "Recorder closed"}
+
+ try:
+ self.dataset.save_episode()
+ self.episode_count += 1
+ ep_frames = self.frame_count # Total frames so far
+ logger.info(f"Episode {self.episode_count} saved: {ep_frames} total frames")
+ return {
+ "status": "success",
+ "episode": self.episode_count,
+ "total_frames": ep_frames,
+ }
+ except Exception as e:
+ logger.error("save_episode failed: %s", e)
+ return {"status": "error", "message": str(e)}
+
+ def finalize(self) -> None:
+ """Finalize the dataset (close parquet writers, flush metadata)."""
+ if self._closed:
+ return
+ try:
+ self.dataset.finalize()
+ except Exception as e:
+ logger.warning("finalize warning: %s", e)
+ self._closed = True
+
+ def push_to_hub(
+ self,
+ tags: list[str] | None = None,
+ private: bool = False,
+ ) -> dict[str, Any]:
+ """Push dataset to HuggingFace Hub.
+
+ Args:
+ tags: Optional tags for the dataset
+ private: Upload as private dataset
+
+ Returns:
+ Dict with push status
+ """
+ try:
+ self.dataset.push_to_hub(tags=tags, private=private)
+ logger.info("Dataset pushed to hub: %s", self.dataset.repo_id)
+ return {
+ "status": "success",
+ "repo_id": self.dataset.repo_id,
+ "episodes": self.episode_count,
+ "frames": self.frame_count,
+ }
+ except Exception as e:
+ logger.error("push_to_hub failed: %s", e)
+ return {"status": "error", "message": str(e)}
+
+ @property
+ def repo_id(self) -> str:
+ return self.dataset.repo_id
+
+ @property
+ def root(self) -> str:
+ return str(self.dataset.root)
+
+ def __repr__(self) -> str:
+ return f"DatasetRecorder(repo_id={self.repo_id}, episodes={self.episode_count}, frames={self.frame_count})"
+
+
+# ── Shared replay-episode helpers ────────────────────────────────────
+
+
+def load_lerobot_episode(repo_id: str, episode: int = 0, root: str | None = None):
+ """Load a LeRobotDataset and resolve the frame range for an episode.
+
+ Used by strands_robots.simulation.mujoco.policy_runner for the
+ replay_episode action — the simulation backend calls this to load
+ recorded joint trajectories and replay them in MuJoCo.
+
+ Returns:
+ Tuple of (dataset, episode_start, episode_length) on success.
+
+ Raises:
+ ImportError: If lerobot is not installed.
+ ValueError: If the episode is out of range or has no frames.
+ """
+ from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+ ds = LeRobotDataset(repo_id=repo_id, root=root)
+
+ num_episodes = ds.meta.total_episodes if hasattr(ds.meta, "total_episodes") else len(ds.meta.episodes)
+ if episode >= num_episodes:
+ raise ValueError(f"Episode {episode} out of range (0-{num_episodes - 1})")
+
+ episode_start = 0
+ episode_length = 0
+ try:
+ if hasattr(ds, "episode_data_index"):
+ from_idx = ds.episode_data_index["from"][episode].item()
+ to_idx = ds.episode_data_index["to"][episode].item()
+ episode_start = from_idx
+ episode_length = to_idx - from_idx
+ else:
+ for i in range(episode):
+ ep_info = ds.meta.episodes[i] if hasattr(ds.meta, "episodes") else {}
+ episode_start += ep_info.get("length", 0)
+ ep_info = ds.meta.episodes[episode] if hasattr(ds.meta, "episodes") else {}
+ episode_length = ep_info.get("length", 0)
+ except Exception:
+ # Last resort: scan frames to find episode boundaries
+ for idx in range(len(ds)):
+ frame = ds[idx]
+ frame_ep = frame.get("episode_index", -1) if hasattr(frame, "get") else -1
+ if hasattr(frame_ep, "item"):
+ frame_ep = frame_ep.item()
+ if frame_ep == episode:
+ if episode_length == 0:
+ episode_start = idx
+ episode_length += 1
+ elif episode_length > 0:
+ break
+
+ if episode_length == 0:
+ raise ValueError(f"Episode {episode} has no frames")
+
+ return ds, episode_start, episode_length
diff --git a/strands_robots/hardware_robot.py b/strands_robots/hardware_robot.py
new file mode 100644
index 0000000..d36f83f
--- /dev/null
+++ b/strands_robots/hardware_robot.py
@@ -0,0 +1,758 @@
+#!/usr/bin/env python3
+"""
+Universal Robot Control with Policy Abstraction for Any VLA Provider
+
+This module provides a clean robot interface that works with any LeRobot-compatible
+robot and any VLA provider through the Policy abstraction.
+
+Features:
+- Async robot task execution with real-time status reporting
+- Non-blocking operations - robot moves while tool returns status
+- Stop functionality to interrupt running tasks
+- Connection state management with proper error handling
+- Policy abstraction for any VLA provider
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import threading
+import time
+from collections.abc import AsyncGenerator
+from concurrent.futures import Future, ThreadPoolExecutor
+from dataclasses import dataclass
+from enum import Enum
+from typing import TYPE_CHECKING, Any, cast
+
+from strands.tools.tools import AgentTool
+from strands.types._events import ToolResultEvent
+from strands.types.tools import ToolResult, ToolSpec, ToolUse
+
+if TYPE_CHECKING:
+ from lerobot.robots.config import RobotConfig
+ from lerobot.robots.robot import Robot as LeRobotRobot
+
+ from .policies import Policy
+
+logger = logging.getLogger(__name__)
+
+
+class TaskStatus(Enum):
+    """Robot task execution status"""
+
+    IDLE = "idle"  # No task started yet (initial RobotTaskState default)
+    CONNECTING = "connecting"  # Establishing connection to robot hardware
+    RUNNING = "running"  # Control loop is actively sending actions
+    COMPLETED = "completed"  # Loop finished its time budget without interruption
+    STOPPED = "stopped"  # Interrupted via stop_task()
+    ERROR = "error"  # Failed; details in RobotTaskState.error_message
+
+@dataclass
+class RobotTaskState:
+    """Robot task execution state"""
+
+    # Current lifecycle phase of the task (see TaskStatus)
+    status: TaskStatus = TaskStatus.IDLE
+    # Natural-language instruction currently (or last) being executed
+    instruction: str = ""
+    # time.time() timestamp taken when the task started
+    start_time: float = 0.0
+    # Elapsed execution time in seconds (refreshed while running and at completion)
+    duration: float = 0.0
+    # Number of individual actions sent to the robot so far
+    step_count: int = 0
+    # Human-readable failure description when status == ERROR
+    error_message: str = ""
+    # Future of the background task submitted to the executor (None when idle)
+    task_future: Future | None = None
+
+
+class Robot(AgentTool):
+    """Universal robot control with async task execution and status reporting."""
+
+    def __init__(
+        self,
+        tool_name: str,
+        robot: LeRobotRobot | RobotConfig | str,
+        cameras: dict[str, dict[str, Any]] | None = None,
+        action_horizon: int = 8,
+        data_config: str | Any | None = None,
+        control_frequency: float = 50.0,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize Robot with async capabilities.
+
+        Args:
+            tool_name: Name for this robot tool
+            robot: LeRobot Robot instance, RobotConfig, or robot type string
+            cameras: Camera configuration dict:
+                {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}}
+            action_horizon: Actions per inference step
+            data_config: Data configuration (for GR00T compatibility)
+            control_frequency: Control loop frequency in Hz (default: 50Hz)
+            **kwargs: Robot-specific parameters (port, etc.)
+        """
+        super().__init__()
+
+        self.tool_name_str = tool_name
+        self.action_horizon = action_horizon
+        self.data_config = data_config
+        self.control_frequency = control_frequency
+        self.action_sleep_time = 1.0 / control_frequency  # Time between actions
+
+        # Task execution state
+        self._task_state = RobotTaskState()
+        # Single worker: at most one background task per robot at a time
+        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"{tool_name}_executor")
+        self._shutdown_event = threading.Event()
+
+        # Initialize robot using lerobot's abstraction
+        self.robot = self._initialize_robot(robot, cameras, **kwargs)
+
+        logger.info(f"🤖 {tool_name} initialized with async capabilities")
+        logger.info(f"📱 Robot: {self.robot.name} (type: {getattr(self.robot, 'robot_type', 'unknown')})")
+        logger.info(f"⏱️ Control frequency: {control_frequency}Hz ({self.action_sleep_time * 1000:.1f}ms per action)")
+
+        # Get camera info if available
+        if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"):
+            cameras_list = list(self.robot.config.cameras.keys())
+            logger.info(f"📹 Cameras: {cameras_list}")
+
+        if data_config:
+            logger.info(f"⚙️ Data config: {data_config}")
+
+    def _initialize_robot(
+        self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs: Any
+    ) -> LeRobotRobot:
+        """Initialize LeRobot robot instance using native lerobot patterns.
+
+        Accepts three input forms: an already-built Robot (returned as-is),
+        a RobotConfig (passed to lerobot's factory), or a robot-type string
+        (converted to a minimal config first).
+
+        Raises:
+            ValueError: If ``robot`` is none of the supported types.
+        """
+        from lerobot.robots.config import RobotConfig
+        from lerobot.robots.robot import Robot as LeRobotRobot
+        from lerobot.robots.utils import make_robot_from_config
+
+        # Direct robot instance - use as-is
+        if isinstance(robot, LeRobotRobot):
+            return robot
+
+        # Robot config - use lerobot's factory
+        elif isinstance(robot, RobotConfig):
+            return make_robot_from_config(robot)
+
+        # Robot type string - create config and use lerobot's factory
+        elif isinstance(robot, str):
+            config = self._create_minimal_config(robot, cameras, **kwargs)
+            return make_robot_from_config(config)
+
+        else:
+            raise ValueError(
+                f"Unsupported robot type: {type(robot)}. "
+                f"Expected LeRobot Robot instance, RobotConfig, or robot type string."
+            )
+
+ def _create_minimal_config(
+ self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs: Any
+ ) -> RobotConfig:
+ """Create minimal robot config using specific robot config classes."""
+ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+
+ # Convert cameras to lerobot format
+ camera_configs = {}
+ if cameras:
+ for name, config in cameras.items():
+ if config.get("type", "opencv") == "opencv":
+ camera_configs[name] = OpenCVCameraConfig(
+ index_or_path=config["index_or_path"],
+ fps=config.get("fps", 30),
+ width=config.get("width", 640),
+ height=config.get("height", 480),
+ rotation=config.get("rotation", 0),
+ color_mode=config.get("color_mode", "rgb"),
+ )
+ else:
+ raise ValueError(f"Unsupported camera type: {config.get('type')}")
+
+ # Map robot type to specific config class
+ config_mapping = {
+ "so101_follower": ("lerobot.robots.so101_follower", "SO101FollowerConfig"),
+ "so100_follower": ("lerobot.robots.so100_follower", "SO100FollowerConfig"),
+ "bi_so100_follower": ("lerobot.robots.bi_so100_follower", "BiSO100FollowerConfig"),
+ "viperx": ("lerobot.robots.viperx", "ViperXConfig"),
+ "koch_follower": ("lerobot.robots.koch_follower", "KochFollowerConfig"),
+ # Add more as needed
+ }
+
+ if robot_type not in config_mapping:
+ raise ValueError(f"Unsupported robot type: {robot_type}. Supported types: {list(config_mapping.keys())}")
+
+ # Import specific config class dynamically
+ module_name, class_name = config_mapping[robot_type]
+ try:
+ import importlib
+
+ module = importlib.import_module(module_name)
+ ConfigClass = getattr(module, class_name)
+ except Exception as e:
+ raise ValueError(f"Failed to import {class_name} from {module_name}: {e}")
+
+ # Create config with proper parameters
+ config_data = {
+ "id": self.tool_name_str,
+ "cameras": camera_configs,
+ }
+
+ # Filter kwargs to only include supported fields for this robot type
+ # Port is common for most serial robots
+ if "port" in kwargs:
+ config_data["port"] = kwargs["port"]
+
+ # Add other common fields as needed
+ for key in ["calibration_dir", "mock", "use_degrees"]:
+ if key in kwargs:
+ config_data[key] = kwargs[key]
+
+ try:
+ return ConfigClass(**config_data)
+ except Exception as e:
+ raise ValueError(f"Failed to create {class_name} for robot type '{robot_type}': {e}. Config: {config_data}")
+
+    async def _get_policy(
+        self, policy_port: int | None = None, policy_host: str = "localhost", policy_provider: str = "groot"
+    ) -> Policy:
+        """Create policy on-the-fly from invocation parameters.
+
+        Args:
+            policy_port: Port of the policy inference service (required).
+            policy_host: Host of the policy inference service.
+            policy_provider: Provider key understood by create_policy (e.g. "groot").
+
+        Returns:
+            A freshly constructed Policy instance.
+
+        Raises:
+            ValueError: If policy_port is missing or falsy.
+        """
+        from .policies import create_policy
+
+        if not policy_port:
+            raise ValueError("policy_port is required for robot operation")
+
+        policy_config = {"port": policy_port, "host": policy_host}
+
+        # Forward data_config when configured (used for GR00T compatibility)
+        if self.data_config:
+            policy_config["data_config"] = self.data_config
+
+        return create_policy(policy_provider, **policy_config)
+
+    async def _connect_robot(self) -> tuple[bool, str]:
+        """Connect to robot hardware with proper error handling.
+
+        Connection and calibration failures are reported as a (False, message)
+        result rather than raised, so callers can record them in task state.
+
+        Returns:
+            tuple[bool, str]: (success, error_message) - error_message is empty on success
+        """
+        try:
+            # Import lerobot exceptions
+            from lerobot.utils.errors import DeviceAlreadyConnectedError
+
+            # Check if already connected
+            if self.robot.is_connected:
+                logger.info(f"✅ {self.robot} already connected")
+                return True, ""
+
+            logger.info(f"🔌 Connecting to {self.robot}...")
+
+            # Handle robot connection using lerobot's error handling patterns
+            try:
+                if not self.robot.is_connected:
+                    # connect() is blocking I/O, so run it off the event loop
+                    await asyncio.to_thread(self.robot.connect, False)  # calibrate=False
+
+            except DeviceAlreadyConnectedError:
+                # This is expected and fine - robot is already connected
+                logger.info(f"✅ {self.robot} was already connected")
+
+            except Exception as e:
+                # Check if it's the string version of "already connected" error
+                error_str = str(e).lower()
+                if "already connected" in error_str or "is already connected" in error_str:
+                    logger.info(f"✅ {self.robot} connection already established")
+                else:
+                    # Re-raise if it's a different error
+                    raise e
+
+            # Final connection check
+            if not self.robot.is_connected:
+                error_msg = f"Failed to connect to {self.robot}"
+                logger.error(f"❌ {error_msg}")
+                return False, error_msg
+
+            # Check robot calibration
+            if hasattr(self.robot, "is_calibrated") and not self.robot.is_calibrated:
+                error_msg = (
+                    f"Robot {self.robot} is not calibrated. Please calibrate the robot manually"
+                    " first using LeRobot's calibration process (lerobot-calibrate)"
+                )
+                logger.error(f"❌ {error_msg}")
+                return False, error_msg
+
+            logger.info(f"✅ {self.robot} connected and ready")
+            return True, ""
+
+        except Exception as e:
+            error_msg = f"Robot connection failed: {e}. Ensure robot is calibrated and accessible on the specified port"
+            logger.error(f"❌ {error_msg}")
+            return False, error_msg
+
+    async def _initialize_policy(self, policy: Policy) -> bool:
+        """Initialize policy with robot state keys.
+
+        Takes one observation from the robot, strips camera keys, and hands
+        the remaining keys to the policy via set_robot_state_keys().
+
+        Returns:
+            True on success, False if observation or policy setup failed.
+        """
+        try:
+            # Get robot state keys from observation
+            test_obs = await asyncio.to_thread(self.robot.get_observation)
+
+            # Filter out camera keys to get robot state keys
+            camera_keys = []
+            if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"):
+                camera_keys = list(self.robot.config.cameras.keys())
+
+            robot_state_keys = [k for k in test_obs.keys() if k not in camera_keys]
+
+            # Set robot state keys in policy
+            policy.set_robot_state_keys(robot_state_keys)
+            return True
+
+        except Exception as e:
+            logger.error(f"❌ Failed to initialize policy: {e}")
+            return False
+
+    async def _execute_task_async(
+        self,
+        instruction: str,
+        policy_port: int | None = None,
+        policy_host: str = "localhost",
+        policy_provider: str = "groot",
+        duration: float = 30.0,
+    ) -> None:
+        """Execute robot task in background thread (internal method).
+
+        Runs the full connect → policy-setup → control-loop pipeline and
+        records progress/errors in self._task_state instead of raising.
+        """
+        try:
+            # Update task state
+            self._task_state.status = TaskStatus.CONNECTING
+            self._task_state.instruction = instruction
+            self._task_state.start_time = time.time()
+            self._task_state.step_count = 0
+            self._task_state.error_message = ""
+
+            # Connect to robot
+            connected, connect_error = await self._connect_robot()
+            if not connected:
+                self._task_state.status = TaskStatus.ERROR
+                self._task_state.error_message = connect_error or f"Failed to connect to {self.tool_name_str}"
+                return
+
+            # Get policy instance
+            policy_instance = await self._get_policy(policy_port, policy_host, policy_provider)
+
+            # Initialize policy with robot state keys
+            if not await self._initialize_policy(policy_instance):
+                self._task_state.status = TaskStatus.ERROR
+                self._task_state.error_message = "Failed to initialize policy"
+                return
+
+            logger.info(f"🎯 Starting task: '{instruction}' on {self.tool_name_str}")
+            logger.info(f"🧠 Using policy: {policy_provider} on {policy_host}:{policy_port}")
+
+            self._task_state.status = TaskStatus.RUNNING
+            start_time = time.time()
+
+            # Main control loop: observe → infer → act, until the time budget
+            # is spent, stop_task() flips the status, or shutdown is signalled.
+            while (
+                time.time() - start_time < duration
+                and self._task_state.status == TaskStatus.RUNNING
+                and not self._shutdown_event.is_set()
+            ):
+                # Get observation from robot
+                observation = await asyncio.to_thread(self.robot.get_observation)
+
+                # Get actions from policy
+                robot_actions = await policy_instance.get_actions(observation, instruction)
+
+                # Execute actions from chunk with proper timing control
+                # Wait between actions for smooth execution
+                for action_dict in robot_actions[: self.action_horizon]:
+                    if self._task_state.status != TaskStatus.RUNNING:
+                        break
+                    await asyncio.to_thread(self.robot.send_action, action_dict)
+                    self._task_state.step_count += 1
+                    # Wait for action to complete before sending next action
+                    # Default 50Hz (0.02s)
+                    await asyncio.sleep(self.action_sleep_time)
+
+            # Update final state
+            elapsed = time.time() - start_time
+            self._task_state.duration = elapsed
+
+            # Only mark COMPLETED if nothing stopped or errored the loop
+            if self._task_state.status == TaskStatus.RUNNING:
+                self._task_state.status = TaskStatus.COMPLETED
+                logger.info(
+                    f"✅ Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)"
+                )
+
+        except Exception as e:
+            logger.error(f"❌ Task execution failed: {e}")
+            self._task_state.status = TaskStatus.ERROR
+            self._task_state.error_message = str(e)
+
+ def _execute_task_sync(
+ self,
+ instruction: str,
+ policy_port: int | None = None,
+ policy_host: str = "localhost",
+ policy_provider: str = "groot",
+ duration: float = 30.0,
+ ) -> dict[str, Any]:
+ """Execute task synchronously in thread - no new event loop."""
+
+ # Import here to avoid conflicts
+ import asyncio
+
+ # Run task without creating new event loop - let it run in thread
+ async def task_runner() -> None:
+ await self._execute_task_async(instruction, policy_port, policy_host, policy_provider, duration)
+
+ # Use asyncio.run only if no loop is running, otherwise run in existing loop
+ try:
+ # Try to get the current event loop
+ asyncio.get_running_loop()
+ # If we're already in an event loop, we need to run in a thread
+ import concurrent.futures
+
+ with concurrent.futures.ThreadPoolExecutor() as exec:
+ future = exec.submit(lambda: asyncio.run(task_runner()))
+ future.result() # Wait for completion
+ except RuntimeError:
+ # No event loop running - safe to create one
+ asyncio.run(task_runner())
+
+ # Return final status
+ return {
+ "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error",
+ "content": [
+ {
+ "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n"
+ f"🤖 Robot: {self.tool_name_str} ({self.robot})\n"
+ f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n"
+ f"⏱️ Duration: {self._task_state.duration:.1f}s\n"
+ f"🎯 Steps: {self._task_state.step_count}"
+ + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "")
+ }
+ ],
+ }
+
+    def start_task(
+        self,
+        instruction: str,
+        policy_port: int | None = None,
+        policy_host: str = "localhost",
+        policy_provider: str = "groot",
+        duration: float = 30.0,
+    ) -> dict[str, Any]:
+        """Start robot task asynchronously and return immediately.
+
+        Submits the task to the single-worker executor and returns a
+        ToolResult-style dict; rejects the request if a task is RUNNING.
+        """
+
+        # Check if task is already running
+        if self._task_state.status == TaskStatus.RUNNING:
+            return {
+                "status": "error",
+                "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}],
+            }
+
+        # Start task in background
+        self._task_state.task_future = self._executor.submit(
+            self._execute_task_sync, instruction, policy_port, policy_host, policy_provider, duration
+        )
+
+        return {
+            "status": "success",
+            "content": [
+                {
+                    "text": f"🚀 Task started: '{instruction}'\n"
+                    f"🤖 Robot: {self.tool_name_str}\n"
+                    f"💡 Use action='status' to check progress\n"
+                    f"💡 Use action='stop' to interrupt"
+                }
+            ],
+        }
+
+    def get_task_status(self) -> dict[str, Any]:
+        """Get current task execution status.
+
+        Returns a ToolResult-style dict with a human-readable summary; also
+        refreshes the duration field while a task is RUNNING.
+        """
+
+        # Update duration for running tasks
+        if self._task_state.status == TaskStatus.RUNNING:
+            self._task_state.duration = time.time() - self._task_state.start_time
+
+        status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n"
+
+        if self._task_state.instruction:
+            status_text += f"🎯 Task: {self._task_state.instruction}\n"
+
+        if self._task_state.status == TaskStatus.RUNNING:
+            status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n"
+            status_text += f"🔄 Steps: {self._task_state.step_count}\n"
+        elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]:
+            status_text += f"⏱️ Total Duration: {self._task_state.duration:.1f}s\n"
+            status_text += f"🎯 Total Steps: {self._task_state.step_count}\n"
+
+        if self._task_state.error_message:
+            status_text += f"❌ Error: {self._task_state.error_message}\n"
+
+        return {
+            "status": "success",
+            "content": [{"text": status_text}],
+        }
+
+    def stop_task(self) -> dict[str, Any]:
+        """Stop currently running task.
+
+        Signals the control loop by flipping status to STOPPED; the loop in
+        _execute_task_async checks this flag between actions.
+        """
+
+        if self._task_state.status != TaskStatus.RUNNING:
+            return {
+                "status": "success",
+                "content": [{"text": f"💤 No task running to stop (current: {self._task_state.status.value})"}],
+            }
+
+        # Signal task to stop
+        self._task_state.status = TaskStatus.STOPPED
+
+        # Cancel future if it exists
+        # NOTE(review): Future.cancel() is a no-op once execution has begun;
+        # the status flag above is what actually stops a running loop.
+        if self._task_state.task_future:
+            self._task_state.task_future.cancel()
+
+        logger.info(f"🛑 Task stopped: {self._task_state.instruction}")
+
+        return {
+            "status": "success",
+            "content": [
+                {
+                    "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n"
+                    f"⏱️ Duration: {self._task_state.duration:.1f}s\n"
+                    f"🎯 Steps completed: {self._task_state.step_count}"
+                }
+            ],
+        }
+
+    @property
+    def tool_name(self) -> str:
+        """Tool name exposed through the AgentTool interface."""
+        return self.tool_name_str
+
+    @property
+    def tool_type(self) -> str:
+        """Tool category identifier; always "robot" for this tool."""
+        return "robot"
+
+    @property
+    def tool_spec(self) -> ToolSpec:
+        """Get tool specification with async actions.
+
+        Only "action" is required at the schema level; instruction and
+        policy_port are enforced for execute/start inside stream().
+        """
+        return {
+            "name": self.tool_name_str,
+            "description": f"Universal robot control with async task execution ({self.robot}). "
+            f"Actions: execute (blocking), start (async), status, stop. "
+            f"For execute/start actions: instruction and policy_port are required. "
+            f"For status/stop actions: no additional parameters needed.",
+            "inputSchema": {
+                "json": {
+                    "type": "object",
+                    "properties": {
+                        "action": {
+                            "type": "string",
+                            "description": "Action to perform: execute (blocking), start (async), status, stop",
+                            "enum": ["execute", "start", "status", "stop"],
+                            "default": "execute",
+                        },
+                        "instruction": {
+                            "type": "string",
+                            "description": "Natural language instruction (required for execute/start actions)",
+                        },
+                        "policy_port": {
+                            "type": "integer",
+                            "description": "Policy service port (required for execute/start actions)",
+                        },
+                        "policy_host": {
+                            "type": "string",
+                            "description": "Policy service host (default: localhost)",
+                            "default": "localhost",
+                        },
+                        "policy_provider": {
+                            "type": "string",
+                            "description": "Policy provider (groot, openai, etc.)",
+                            "default": "groot",
+                        },
+                        "duration": {
+                            "type": "number",
+                            "description": "Maximum execution time in seconds",
+                            "default": 30.0,
+                        },
+                    },
+                    "required": ["action"],
+                }
+            },
+        }
+
+    @staticmethod
+    def _make_tool_result(tool_use_id: str, result: dict[str, Any]) -> ToolResult:
+        """Create a ToolResult dict with the given tool_use_id merged into result.
+
+        Keys already in ``result`` win on collision (result is spread last).
+        """
+        return cast(ToolResult, {"toolUseId": tool_use_id, **result})
+
+ async def stream(
+ self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any
+ ) -> AsyncGenerator[ToolResultEvent, None]:
+ """Stream robot task execution with async actions."""
+ try:
+ tool_use_id = tool_use.get("toolUseId", "")
+ input_data = tool_use.get("input", {})
+
+ action = input_data.get("action", "execute")
+
+ # Handle different actions
+ if action == "execute":
+ # Blocking execution (legacy behavior)
+ instruction = input_data.get("instruction", "")
+ policy_port = input_data.get("policy_port")
+ policy_host = input_data.get("policy_host", "localhost")
+ policy_provider = input_data.get("policy_provider", "groot")
+ duration = input_data.get("duration", 30.0)
+
+ if not instruction or not policy_port:
+ yield ToolResultEvent(
+ self._make_tool_result(
+ tool_use_id,
+ {
+ "status": "error",
+ "content": [{"text": "❌ instruction and policy_port are required for execute action"}],
+ },
+ )
+ )
+ return
+
+ # Execute task synchronously
+ task_result = self._execute_task_sync(instruction, policy_port, policy_host, policy_provider, duration)
+ yield ToolResultEvent(self._make_tool_result(tool_use_id, task_result))
+
+ elif action == "start":
+ # Asynchronous execution start
+ instruction = input_data.get("instruction", "")
+ policy_port = input_data.get("policy_port")
+ policy_host = input_data.get("policy_host", "localhost")
+ policy_provider = input_data.get("policy_provider", "groot")
+ duration = input_data.get("duration", 30.0)
+
+ if not instruction or not policy_port:
+ yield ToolResultEvent(
+ self._make_tool_result(
+ tool_use_id,
+ {
+ "status": "error",
+ "content": [{"text": "❌ instruction and policy_port are required for start action"}],
+ },
+ )
+ )
+ return
+
+ # Start task asynchronously
+ start_result = self.start_task(instruction, policy_port, policy_host, policy_provider, duration)
+ yield ToolResultEvent(self._make_tool_result(tool_use_id, start_result))
+
+ elif action == "status":
+ # Get current task status
+ status_result = self.get_task_status()
+ yield ToolResultEvent(self._make_tool_result(tool_use_id, status_result))
+
+ elif action == "stop":
+ # Stop current task
+ stop_result = self.stop_task()
+ yield ToolResultEvent(self._make_tool_result(tool_use_id, stop_result))
+
+ else:
+ yield ToolResultEvent(
+ self._make_tool_result(
+ tool_use_id,
+ {
+ "status": "error",
+ "content": [
+ {"text": f"❌ Unknown action: {action}. Valid actions: execute, start, status, stop"}
+ ],
+ },
+ )
+ )
+
+ except Exception as e:
+ logger.error(f"❌ {self.tool_name_str} error: {e}")
+ yield ToolResultEvent(
+ self._make_tool_result(
+ tool_use_id,
+ {
+ "status": "error",
+ "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}],
+ },
+ )
+ )
+
+    def cleanup(self) -> None:
+        """Cleanup resources and stop any running tasks.
+
+        Idempotent; errors are logged rather than raised.
+        """
+        try:
+            # Signal shutdown (checked by the control loop between iterations)
+            self._shutdown_event.set()
+
+            # Stop any running task
+            if self._task_state.status == TaskStatus.RUNNING:
+                self.stop_task()
+
+            # Shutdown executor
+            self._executor.shutdown(wait=True)
+
+            logger.info(f"🧹 {self.tool_name_str} cleanup completed")
+
+        except Exception as e:
+            logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}")
+
+    def __del__(self) -> None:
+        """Destructor to ensure cleanup."""
+        try:
+            self.cleanup()
+        except Exception:
+            # Destructors must never propagate exceptions (may run at interpreter shutdown)
+            pass  # Ignore errors in destructor
+
+    async def get_status(self) -> dict[str, Any]:
+        """Get robot status including connection and task state.
+
+        Returns:
+            Dict with connection/calibration flags, camera names, and current
+            task state; on failure, a minimal dict carrying an "error" key.
+        """
+        try:
+            # Get robot connection status (defaults chosen to be conservative:
+            # unknown connection → False, unknown calibration → True)
+            is_connected = self.robot.is_connected if hasattr(self.robot, "is_connected") else False
+            is_calibrated = self.robot.is_calibrated if hasattr(self.robot, "is_calibrated") else True
+
+            # Get camera status
+            camera_status = []
+            if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"):
+                for name in self.robot.config.cameras.keys():
+                    camera_status.append(name)
+
+            # Build status dict
+            status_data = {
+                "robot_name": self.tool_name_str,
+                "robot_type": getattr(self.robot, "robot_type", self.robot.name),
+                "robot_info": str(self.robot),
+                "data_config": self.data_config,
+                "is_connected": is_connected,
+                "is_calibrated": is_calibrated,
+                "cameras": camera_status,
+                "task_status": self._task_state.status.value,
+                "current_instruction": self._task_state.instruction,
+                "task_duration": self._task_state.duration,
+                "task_steps": self._task_state.step_count,
+            }
+
+            # Add error info if present
+            if self._task_state.error_message:
+                status_data["task_error"] = self._task_state.error_message
+
+            return status_data
+
+        except Exception as e:
+            logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}")
+            return {
+                "robot_name": self.tool_name_str,
+                "error": str(e),
+                "is_connected": False,
+                "task_status": "error",
+            }
+
+    async def stop(self) -> None:
+        """Stop robot and disconnect.
+
+        Stops any running task, disconnects the hardware, and releases the
+        executor via cleanup(). Errors are logged rather than raised.
+        """
+        try:
+            # Stop any running task first
+            if self._task_state.status == TaskStatus.RUNNING:
+                self.stop_task()
+
+            # Disconnect robot hardware (blocking I/O, run off the event loop)
+            if hasattr(self.robot, "disconnect"):
+                await asyncio.to_thread(self.robot.disconnect)
+
+            # Cleanup resources
+            self.cleanup()
+
+            logger.info(f"🛑 {self.tool_name_str} stopped and disconnected")
+
+        except Exception as e:
+            logger.error(f"❌ Error stopping robot: {e}")
diff --git a/strands_robots/robot.py b/strands_robots/robot.py
index 33c5ba7..fa4a364 100644
--- a/strands_robots/robot.py
+++ b/strands_robots/robot.py
@@ -1,758 +1,205 @@
-#!/usr/bin/env python3
-"""
-Universal Robot Control with Policy Abstraction for Any VLA Provider
-
-This module provides a clean robot interface that works with any LeRobot-compatible
-robot and any VLA provider through the Policy abstraction.
-
-Features:
-- Async robot task execution with real-time status reporting
-- Non-blocking operations - robot moves while tool returns status
-- Stop functionality to interrupt running tasks
-- Connection state management with proper error handling
-- Policy abstraction for any VLA provider
-"""
-
-from __future__ import annotations
-
-import asyncio
-import logging
-import threading
-import time
-from collections.abc import AsyncGenerator
-from concurrent.futures import Future, ThreadPoolExecutor
-from dataclasses import dataclass
-from enum import Enum
-from typing import TYPE_CHECKING, Any, cast
-
-from strands.tools.tools import AgentTool
-from strands.types._events import ToolResultEvent
-from strands.types.tools import ToolResult, ToolSpec, ToolUse
+"""Unified Robot factory — convenience layer over ``strands_robots.simulation``
+and ``strands_robots.hardware_robot``.
-if TYPE_CHECKING:
- from lerobot.robots.config import RobotConfig
- from lerobot.robots.robot import Robot as LeRobotRobot
+Provides:
+ - ``Robot("so100")`` → returns a simulation by default (safe)
+ - ``Robot("so100", mode="real")`` → explicit real hardware
+ - ``Robot("so100", mode="auto")`` → auto-detects sim/real
+ - ``list_robots()`` → what's available
- from .policies import Policy
+Environment Variables:
+ STRANDS_ROBOT_MODE: Override mode detection ("sim" or "real").
-logger = logging.getLogger(__name__)
-
-
-class TaskStatus(Enum):
- """Robot task execution status"""
-
- IDLE = "idle"
- CONNECTING = "connecting"
- RUNNING = "running"
- COMPLETED = "completed"
- STOPPED = "stopped"
- ERROR = "error"
-
-
-@dataclass
-class RobotTaskState:
- """Robot task execution state"""
-
- status: TaskStatus = TaskStatus.IDLE
- instruction: str = ""
- start_time: float = 0.0
- duration: float = 0.0
- step_count: int = 0
- error_message: str = ""
- task_future: Future | None = None
-
-
-class Robot(AgentTool):
- """Universal robot control with async task execution and status reporting."""
-
- def __init__(
- self,
- tool_name: str,
- robot: LeRobotRobot | RobotConfig | str,
- cameras: dict[str, dict[str, Any]] | None = None,
- action_horizon: int = 8,
- data_config: str | Any | None = None,
- control_frequency: float = 50.0,
- **kwargs,
- ):
- """Initialize Robot with async capabilities.
-
- Args:
- tool_name: Name for this robot tool
- robot: LeRobot Robot instance, RobotConfig, or robot type string
- cameras: Camera configuration dict:
- {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}}
- action_horizon: Actions per inference step
- data_config: Data configuration (for GR00T compatibility)
- control_frequency: Control loop frequency in Hz (default: 50Hz)
- **kwargs: Robot-specific parameters (port, etc.)
- """
- super().__init__()
-
- self.tool_name_str = tool_name
- self.action_horizon = action_horizon
- self.data_config = data_config
- self.control_frequency = control_frequency
- self.action_sleep_time = 1.0 / control_frequency # Time between actions
-
- # Task execution state
- self._task_state = RobotTaskState()
- self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"{tool_name}_executor")
- self._shutdown_event = threading.Event()
-
- # Initialize robot using lerobot's abstraction
- self.robot = self._initialize_robot(robot, cameras, **kwargs)
-
- logger.info(f"🤖 {tool_name} initialized with async capabilities")
- logger.info(f"📱 Robot: {self.robot.name} (type: {getattr(self.robot, 'robot_type', 'unknown')})")
- logger.info(f"⏱️ Control frequency: {control_frequency}Hz ({self.action_sleep_time * 1000:.1f}ms per action)")
-
- # Get camera info if available
- if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"):
- cameras_list = list(self.robot.config.cameras.keys())
- logger.info(f"📹 Cameras: {cameras_list}")
-
- if data_config:
- logger.info(f"⚙️ Data config: {data_config}")
-
- def _initialize_robot(
- self, robot: LeRobotRobot | RobotConfig | str, cameras: dict[str, dict[str, Any]] | None, **kwargs
- ) -> LeRobotRobot:
- """Initialize LeRobot robot instance using native lerobot patterns."""
- from lerobot.robots.config import RobotConfig
- from lerobot.robots.robot import Robot as LeRobotRobot
- from lerobot.robots.utils import make_robot_from_config
-
- # Direct robot instance - use as-is
- if isinstance(robot, LeRobotRobot):
- return robot
-
- # Robot config - use lerobot's factory
- elif isinstance(robot, RobotConfig):
- return make_robot_from_config(robot)
-
- # Robot type string - create config and use lerobot's factory
- elif isinstance(robot, str):
- config = self._create_minimal_config(robot, cameras, **kwargs)
- return make_robot_from_config(config)
-
- else:
- raise ValueError(
- f"Unsupported robot type: {type(robot)}. "
- f"Expected LeRobot Robot instance, RobotConfig, or robot type string."
- )
+Examples::
- def _create_minimal_config(
- self, robot_type: str, cameras: dict[str, dict[str, Any]] | None, **kwargs
- ) -> RobotConfig:
- """Create minimal robot config using specific robot config classes."""
- from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-
- # Convert cameras to lerobot format
- camera_configs = {}
- if cameras:
- for name, config in cameras.items():
- if config.get("type", "opencv") == "opencv":
- camera_configs[name] = OpenCVCameraConfig(
- index_or_path=config["index_or_path"],
- fps=config.get("fps", 30),
- width=config.get("width", 640),
- height=config.get("height", 480),
- rotation=config.get("rotation", 0),
- color_mode=config.get("color_mode", "rgb"),
- )
- else:
- raise ValueError(f"Unsupported camera type: {config.get('type')}")
-
- # Map robot type to specific config class
- config_mapping = {
- "so101_follower": ("lerobot.robots.so101_follower", "SO101FollowerConfig"),
- "so100_follower": ("lerobot.robots.so100_follower", "SO100FollowerConfig"),
- "bi_so100_follower": ("lerobot.robots.bi_so100_follower", "BiSO100FollowerConfig"),
- "viperx": ("lerobot.robots.viperx", "ViperXConfig"),
- "koch_follower": ("lerobot.robots.koch_follower", "KochFollowerConfig"),
- # Add more as needed
- }
-
- if robot_type not in config_mapping:
- raise ValueError(f"Unsupported robot type: {robot_type}. Supported types: {list(config_mapping.keys())}")
+ # Default: simulation (safe — no physical hardware interaction)
+ sim = Robot("so100")
- # Import specific config class dynamically
- module_name, class_name = config_mapping[robot_type]
- try:
- import importlib
+ # Explicit real hardware
+ hw = Robot("so100", mode="real", cameras={...})
- module = importlib.import_module(module_name)
- ConfigClass = getattr(module, class_name)
- except Exception as e:
- raise ValueError(f"Failed to import {class_name} from {module_name}: {e}")
+ # Auto-detect (probes USB for servo controllers)
+ robot = Robot("so100", mode="auto")
- # Create config with proper parameters
- config_data = {
- "id": self.tool_name_str,
- "cameras": camera_configs,
- }
+ # With custom URDF/MJCF path
+ sim = Robot("my_arm", urdf_path="/path/to/robot.xml")
- # Filter kwargs to only include supported fields for this robot type
- # Port is common for most serial robots
- if "port" in kwargs:
- config_data["port"] = kwargs["port"]
+Future (not yet implemented)::
- # Add other common fields as needed
- for key in ["calibration_dir", "mock", "use_degrees"]:
- if key in kwargs:
- config_data[key] = kwargs[key]
-
- try:
- return ConfigClass(**config_data)
- except Exception as e:
- raise ValueError(f"Failed to create {class_name} for robot type '{robot_type}': {e}. Config: {config_data}")
+ sim = Robot("unitree_go2", backend="isaac", num_envs=4096)
+ sim = Robot("so100", backend="newton", num_envs=4096)
+"""
- async def _get_policy(
- self, policy_port: int | None = None, policy_host: str = "localhost", policy_provider: str = "groot"
- ) -> Policy:
- """Create policy on-the-fly from invocation parameters."""
- from .policies import create_policy
+import logging
+import os
+from typing import Any
- if not policy_port:
- raise ValueError("policy_port is required for robot operation")
+from strands_robots.registry import (
+ get_hardware_type,
+ has_hardware,
+ resolve_name,
+)
- policy_config = {"port": policy_port, "host": policy_host}
+logger = logging.getLogger(__name__)
- if self.data_config:
- policy_config["data_config"] = self.data_config
- return create_policy(policy_provider, **policy_config)
+def _auto_detect_mode(canonical: str) -> str:
+ """Auto-detect sim vs real mode.
- async def _connect_robot(self) -> tuple[bool, str]:
- """Connect to robot hardware with proper error handling.
+ Priority:
+ 1. ``STRANDS_ROBOT_MODE`` env var (explicit override)
+ 2. Robot-specific USB detection (Feetech/Dynamixel servo controllers)
+ 3. Default to sim (safest — never accidentally send commands to hardware)
+ """
+ env_mode = os.getenv("STRANDS_ROBOT_MODE", "").lower()
+ if env_mode in ("sim", "real"):
+ return env_mode
- Returns:
- tuple[bool, str]: (success, error_message) - error_message is empty on success
- """
+ # Only probe USB if the robot actually has hardware support
+ if has_hardware(canonical):
try:
- # Import lerobot exceptions
- from lerobot.utils.errors import DeviceAlreadyConnectedError
-
- # Check if already connected
- if self.robot.is_connected:
- logger.info(f"✅ {self.robot} already connected")
- return True, ""
-
- logger.info(f"🔌 Connecting to {self.robot}...")
-
- # Handle robot connection using lerobot's error handling patterns
- try:
- if not self.robot.is_connected:
- await asyncio.to_thread(self.robot.connect, False) # calibrate=False
-
- except DeviceAlreadyConnectedError:
- # This is expected and fine - robot is already connected
- logger.info(f"✅ {self.robot} was already connected")
-
- except Exception as e:
- # Check if it's the string version of "already connected" error
- error_str = str(e).lower()
- if "already connected" in error_str or "is already connected" in error_str:
- logger.info(f"✅ {self.robot} connection already established")
- else:
- # Re-raise if it's a different error
- raise e
-
- # Final connection check
- if not self.robot.is_connected:
- error_msg = f"Failed to connect to {self.robot}"
- logger.error(f"❌ {error_msg}")
- return False, error_msg
-
- # Check robot calibration
- if hasattr(self.robot, "is_calibrated") and not self.robot.is_calibrated:
- error_msg = (
- f"Robot {self.robot} is not calibrated. Please calibrate the robot manually"
- " first using LeRobot's calibration process (lerobot-calibrate)"
+ import serial.tools.list_ports
+
+ ports = list(serial.tools.list_ports.comports())
+ servo_keywords = ["feetech", "dynamixel", "sts3215", "xl430", "xl330"]
+ exclude = ["bluetooth", "internal", "debug", "apple", "modem"]
+ robot_ports = [
+ p
+ for p in ports
+ if any(
+ kw in ((p.description or "") + (getattr(p, "manufacturer", None) or "")).lower()
+ for kw in servo_keywords
)
- logger.error(f"❌ {error_msg}")
- return False, error_msg
-
- logger.info(f"✅ {self.robot} connected and ready")
- return True, ""
-
- except Exception as e:
- error_msg = f"Robot connection failed: {e}. Ensure robot is calibrated and accessible on the specified port"
- logger.error(f"❌ {error_msg}")
- return False, error_msg
-
- async def _initialize_policy(self, policy: Policy) -> bool:
- """Initialize policy with robot state keys."""
- try:
- # Get robot state keys from observation
- test_obs = await asyncio.to_thread(self.robot.get_observation)
-
- # Filter out camera keys to get robot state keys
- camera_keys = []
- if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"):
- camera_keys = list(self.robot.config.cameras.keys())
-
- robot_state_keys = [k for k in test_obs.keys() if k not in camera_keys]
-
- # Set robot state keys in policy
- policy.set_robot_state_keys(robot_state_keys)
- return True
-
- except Exception as e:
- logger.error(f"❌ Failed to initialize policy: {e}")
- return False
-
- async def _execute_task_async(
- self,
- instruction: str,
- policy_port: int | None = None,
- policy_host: str = "localhost",
- policy_provider: str = "groot",
- duration: float = 30.0,
- ) -> None:
- """Execute robot task in background thread (internal method)."""
- try:
- # Update task state
- self._task_state.status = TaskStatus.CONNECTING
- self._task_state.instruction = instruction
- self._task_state.start_time = time.time()
- self._task_state.step_count = 0
- self._task_state.error_message = ""
-
- # Connect to robot
- connected, connect_error = await self._connect_robot()
- if not connected:
- self._task_state.status = TaskStatus.ERROR
- self._task_state.error_message = connect_error or f"Failed to connect to {self.tool_name_str}"
- return
-
- # Get policy instance
- policy_instance = await self._get_policy(policy_port, policy_host, policy_provider)
-
- # Initialize policy with robot state keys
- if not await self._initialize_policy(policy_instance):
- self._task_state.status = TaskStatus.ERROR
- self._task_state.error_message = "Failed to initialize policy"
- return
-
- logger.info(f"🎯 Starting task: '{instruction}' on {self.tool_name_str}")
- logger.info(f"🧠 Using policy: {policy_provider} on {policy_host}:{policy_port}")
-
- self._task_state.status = TaskStatus.RUNNING
- start_time = time.time()
-
- while (
- time.time() - start_time < duration
- and self._task_state.status == TaskStatus.RUNNING
- and not self._shutdown_event.is_set()
- ):
- # Get observation from robot
- observation = await asyncio.to_thread(self.robot.get_observation)
-
- # Get actions from policy
- robot_actions = await policy_instance.get_actions(observation, instruction)
-
- # Execute actions from chunk with proper timing control
- # Wait between actions for smooth execution
- for action_dict in robot_actions[: self.action_horizon]:
- if self._task_state.status != TaskStatus.RUNNING:
- break
- await asyncio.to_thread(self.robot.send_action, action_dict)
- self._task_state.step_count += 1
- # Wait for action to complete before sending next action
- # Default 50Hz (0.02s)
- await asyncio.sleep(self.action_sleep_time)
-
- # Update final state
- elapsed = time.time() - start_time
- self._task_state.duration = elapsed
-
- if self._task_state.status == TaskStatus.RUNNING:
- self._task_state.status = TaskStatus.COMPLETED
+ and not any(s in (p.description or "").lower() for s in exclude)
+ ]
+ if robot_ports:
logger.info(
- f"✅ Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)"
+ "Auto-detected robot hardware: %s",
+ [p.device for p in robot_ports],
)
+ return "real"
+ except (ImportError, OSError): # USB probing may fail with OSError on permission/device issues
+ pass
+
+ return "sim"
+
+
+def Robot(
+ name: str,
+ mode: str = "sim",
+ backend: str = "mujoco",
+ urdf_path: str | None = None,
+ cameras: dict[str, dict[str, Any]] | None = None,
+ position: list[float] | None = None,
+ **kwargs: Any,
+) -> Any:
+ """Create a robot — returns a Simulation or HardwareRobot instance.
+
+ This is a convenience factory, NOT a wrapper class. You get the real
+ backend instance back — with full access to all its methods.
+
+ Defaults to simulation mode so that ``Robot("so100")`` never
+ accidentally sends commands to physical hardware. Use
+ ``mode="real"`` to explicitly opt into hardware control.
+
+ Args:
+ name: Robot name ("so100", "aloha", "unitree_g1", "panda", ...)
+ Accepts any alias defined in ``registry/robots.json``.
+ mode: "sim" (default — safe), "real" (explicit hardware), or
+ "auto" (probes USB for servo controllers, falls back to sim).
+ backend: Simulation backend — currently only "mujoco" (CPU).
+ Future: "isaac" (GPU), "newton" (GPU).
+ urdf_path: Explicit path to URDF/MJCF file. If not provided,
+ resolved via ``strands_robots.simulation.model_registry``
+ (asset manager or ``STRANDS_ASSETS_DIR`` search paths).
+ cameras: Camera config for real hardware. Example::
+
+ {"wrist": {"type": "opencv", "index_or_path": "/dev/video0", "fps": 30}}
+
+ position: Robot position in sim world [x, y, z].
+ **kwargs: Forwarded to the underlying backend constructor.
+
+ Returns:
+ ``strands_robots.simulation.Simulation`` (sim) or
+ ``strands_robots.hardware_robot.Robot`` (real hardware).
+
+ Raises:
+ RuntimeError: If the sim world or robot fails to initialize.
+ NotImplementedError: If an unimplemented backend is requested.
+
+ Examples::
+
+ # Simulation (default — safe)
+ sim = Robot("so100")
+
+ # Explicit MJCF model path
+ sim = Robot("my_arm", urdf_path="path/to/robot.xml")
+
+ # Real hardware (explicit opt-in)
+ hw = Robot("so100", mode="real", cameras={...})
+
+ # Auto-detect (probes USB, falls back to sim)
+ robot = Robot("so100", mode="auto")
+
+ # The 5-line promise
+ from strands_robots import Robot
+ from strands import Agent
+ robot = Robot("so100")
+ agent = Agent(tools=[robot])
+ agent("Pick up the red cube")
+ """
+ canonical = resolve_name(name)
+
+ if mode == "auto":
+ mode = _auto_detect_mode(canonical)
+
+ # ── Simulation ──
+ if mode == "sim":
+ if backend != "mujoco":
+ raise NotImplementedError(
+ f"Backend {backend!r} is not yet implemented. "
+ f"Currently supported: 'mujoco'. "
+ f"Isaac and Newton backends are on the roadmap."
+ )
- except Exception as e:
- logger.error(f"❌ Task execution failed: {e}")
- self._task_state.status = TaskStatus.ERROR
- self._task_state.error_message = str(e)
-
- def _execute_task_sync(
- self,
- instruction: str,
- policy_port: int | None = None,
- policy_host: str = "localhost",
- policy_provider: str = "groot",
- duration: float = 30.0,
- ) -> dict[str, Any]:
- """Execute task synchronously in thread - no new event loop."""
-
- # Import here to avoid conflicts
- import asyncio
-
- # Run task without creating new event loop - let it run in thread
- async def task_runner():
- await self._execute_task_async(instruction, policy_port, policy_host, policy_provider, duration)
-
- # Use asyncio.run only if no loop is running, otherwise run in existing loop
- try:
- # Try to get the current event loop
- asyncio.get_running_loop()
- # If we're already in an event loop, we need to run in a thread
- import concurrent.futures
-
- with concurrent.futures.ThreadPoolExecutor() as exec:
- future = exec.submit(lambda: asyncio.run(task_runner()))
- future.result() # Wait for completion
- except RuntimeError:
- # No event loop running - safe to create one
- asyncio.run(task_runner())
-
- # Return final status
- return {
- "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error",
- "content": [
- {
- "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n"
- f"🤖 Robot: {self.tool_name_str} ({self.robot})\n"
- f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n"
- f"⏱️ Duration: {self._task_state.duration:.1f}s\n"
- f"🎯 Steps: {self._task_state.step_count}"
- + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "")
- }
- ],
- }
+ from strands_robots.simulation import Simulation
- def start_task(
- self,
- instruction: str,
- policy_port: int | None = None,
- policy_host: str = "localhost",
- policy_provider: str = "groot",
- duration: float = 30.0,
- ) -> dict[str, Any]:
- """Start robot task asynchronously and return immediately."""
-
- # Check if task is already running
- if self._task_state.status == TaskStatus.RUNNING:
- return {
- "status": "error",
- "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}],
- }
-
- # Start task in background
- self._task_state.task_future = self._executor.submit(
- self._execute_task_sync, instruction, policy_port, policy_host, policy_provider, duration
+ sim = Simulation(
+ tool_name=f"{canonical}_sim",
+ **kwargs,
)
+ sim._dispatch_action("create_world", {})
- return {
- "status": "success",
- "content": [
- {
- "text": f"🚀 Task started: '{instruction}'\n"
- f"🤖 Robot: {self.tool_name_str}\n"
- f"💡 Use action='status' to check progress\n"
- f"💡 Use action='stop' to interrupt"
- }
- ],
+ add_robot_params: dict[str, Any] = {
+ "robot_name": canonical,
+ "data_config": canonical,
+ "position": position or [0.0, 0.0, 0.0],
}
+ if urdf_path:
+ add_robot_params["urdf_path"] = urdf_path
+
+ result = sim._dispatch_action("add_robot", add_robot_params)
+ if result.get("status") == "error":
+ sim.destroy() # Clean up partial initialization (executor, temp dir, MuJoCo world)
+ content = result.get("content", [])
+ msg = content[0].get("text", str(result)) if content else str(result)
+ raise RuntimeError(f"Failed to create sim robot '{canonical}': {msg}")
+ return sim
+
+ # ── Real hardware (explicit opt-in) ──
+ elif mode == "real":
+ from strands_robots.hardware_robot import Robot as HardwareRobot
+
+ real_type = get_hardware_type(canonical) or canonical
+ return HardwareRobot(
+ tool_name=canonical,
+ robot=real_type,
+ cameras=cameras,
+ **kwargs,
+ )
- def get_task_status(self) -> dict[str, Any]:
- """Get current task execution status."""
-
- # Update duration for running tasks
- if self._task_state.status == TaskStatus.RUNNING:
- self._task_state.duration = time.time() - self._task_state.start_time
-
- status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n"
-
- if self._task_state.instruction:
- status_text += f"🎯 Task: {self._task_state.instruction}\n"
-
- if self._task_state.status == TaskStatus.RUNNING:
- status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n"
- status_text += f"🔄 Steps: {self._task_state.step_count}\n"
- elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]:
- status_text += f"⏱️ Total Duration: {self._task_state.duration:.1f}s\n"
- status_text += f"🎯 Total Steps: {self._task_state.step_count}\n"
-
- if self._task_state.error_message:
- status_text += f"❌ Error: {self._task_state.error_message}\n"
-
- return {
- "status": "success",
- "content": [{"text": status_text}],
- }
-
- def stop_task(self) -> dict[str, Any]:
- """Stop currently running task."""
-
- if self._task_state.status != TaskStatus.RUNNING:
- return {
- "status": "success",
- "content": [{"text": f"💤 No task running to stop (current: {self._task_state.status.value})"}],
- }
-
- # Signal task to stop
- self._task_state.status = TaskStatus.STOPPED
-
- # Cancel future if it exists
- if self._task_state.task_future:
- self._task_state.task_future.cancel()
-
- logger.info(f"🛑 Task stopped: {self._task_state.instruction}")
-
- return {
- "status": "success",
- "content": [
- {
- "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n"
- f"⏱️ Duration: {self._task_state.duration:.1f}s\n"
- f"🎯 Steps completed: {self._task_state.step_count}"
- }
- ],
- }
-
- @property
- def tool_name(self) -> str:
- return self.tool_name_str
-
- @property
- def tool_type(self) -> str:
- return "robot"
-
- @property
- def tool_spec(self) -> ToolSpec:
- """Get tool specification with async actions."""
- return {
- "name": self.tool_name_str,
- "description": f"Universal robot control with async task execution ({self.robot}). "
- f"Actions: execute (blocking), start (async), status, stop. "
- f"For execute/start actions: instruction and policy_port are required. "
- f"For status/stop actions: no additional parameters needed.",
- "inputSchema": {
- "json": {
- "type": "object",
- "properties": {
- "action": {
- "type": "string",
- "description": "Action to perform: execute (blocking), start (async), status, stop",
- "enum": ["execute", "start", "status", "stop"],
- "default": "execute",
- },
- "instruction": {
- "type": "string",
- "description": "Natural language instruction (required for execute/start actions)",
- },
- "policy_port": {
- "type": "integer",
- "description": "Policy service port (required for execute/start actions)",
- },
- "policy_host": {
- "type": "string",
- "description": "Policy service host (default: localhost)",
- "default": "localhost",
- },
- "policy_provider": {
- "type": "string",
- "description": "Policy provider (groot, openai, etc.)",
- "default": "groot",
- },
- "duration": {
- "type": "number",
- "description": "Maximum execution time in seconds",
- "default": 30.0,
- },
- },
- "required": ["action"],
- }
- },
- }
-
- @staticmethod
- def _make_tool_result(tool_use_id: str, result: dict[str, Any]) -> ToolResult:
- """Create a ToolResult dict with the given tool_use_id merged into result."""
- return cast(ToolResult, {"toolUseId": tool_use_id, **result})
-
- async def stream(
- self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any
- ) -> AsyncGenerator[ToolResultEvent, None]:
- """Stream robot task execution with async actions."""
- try:
- tool_use_id = tool_use.get("toolUseId", "")
- input_data = tool_use.get("input", {})
-
- action = input_data.get("action", "execute")
-
- # Handle different actions
- if action == "execute":
- # Blocking execution (legacy behavior)
- instruction = input_data.get("instruction", "")
- policy_port = input_data.get("policy_port")
- policy_host = input_data.get("policy_host", "localhost")
- policy_provider = input_data.get("policy_provider", "groot")
- duration = input_data.get("duration", 30.0)
-
- if not instruction or not policy_port:
- yield ToolResultEvent(
- self._make_tool_result(
- tool_use_id,
- {
- "status": "error",
- "content": [{"text": "❌ instruction and policy_port are required for execute action"}],
- },
- )
- )
- return
-
- # Execute task synchronously
- task_result = self._execute_task_sync(instruction, policy_port, policy_host, policy_provider, duration)
- yield ToolResultEvent(self._make_tool_result(tool_use_id, task_result))
-
- elif action == "start":
- # Asynchronous execution start
- instruction = input_data.get("instruction", "")
- policy_port = input_data.get("policy_port")
- policy_host = input_data.get("policy_host", "localhost")
- policy_provider = input_data.get("policy_provider", "groot")
- duration = input_data.get("duration", 30.0)
-
- if not instruction or not policy_port:
- yield ToolResultEvent(
- self._make_tool_result(
- tool_use_id,
- {
- "status": "error",
- "content": [{"text": "❌ instruction and policy_port are required for start action"}],
- },
- )
- )
- return
-
- # Start task asynchronously
- start_result = self.start_task(instruction, policy_port, policy_host, policy_provider, duration)
- yield ToolResultEvent(self._make_tool_result(tool_use_id, start_result))
-
- elif action == "status":
- # Get current task status
- status_result = self.get_task_status()
- yield ToolResultEvent(self._make_tool_result(tool_use_id, status_result))
-
- elif action == "stop":
- # Stop current task
- stop_result = self.stop_task()
- yield ToolResultEvent(self._make_tool_result(tool_use_id, stop_result))
-
- else:
- yield ToolResultEvent(
- self._make_tool_result(
- tool_use_id,
- {
- "status": "error",
- "content": [
- {"text": f"❌ Unknown action: {action}. Valid actions: execute, start, status, stop"}
- ],
- },
- )
- )
-
- except Exception as e:
- logger.error(f"❌ {self.tool_name_str} error: {e}")
- yield ToolResultEvent(
- self._make_tool_result(
- tool_use_id,
- {
- "status": "error",
- "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}],
- },
- )
- )
-
- def cleanup(self):
- """Cleanup resources and stop any running tasks."""
- try:
- # Signal shutdown
- self._shutdown_event.set()
-
- # Stop any running task
- if self._task_state.status == TaskStatus.RUNNING:
- self.stop_task()
-
- # Shutdown executor
- self._executor.shutdown(wait=True)
-
- logger.info(f"🧹 {self.tool_name_str} cleanup completed")
-
- except Exception as e:
- logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}")
-
- def __del__(self):
- """Destructor to ensure cleanup."""
- try:
- self.cleanup()
- except Exception:
- pass # Ignore errors in destructor
-
- async def get_status(self) -> dict[str, Any]:
- """Get robot status including connection and task state."""
- try:
- # Get robot connection status
- is_connected = self.robot.is_connected if hasattr(self.robot, "is_connected") else False
- is_calibrated = self.robot.is_calibrated if hasattr(self.robot, "is_calibrated") else True
-
- # Get camera status
- camera_status = []
- if hasattr(self.robot, "config") and hasattr(self.robot.config, "cameras"):
- for name in self.robot.config.cameras.keys():
- camera_status.append(name)
-
- # Build status dict
- status_data = {
- "robot_name": self.tool_name_str,
- "robot_type": getattr(self.robot, "robot_type", self.robot.name),
- "robot_info": str(self.robot),
- "data_config": self.data_config,
- "is_connected": is_connected,
- "is_calibrated": is_calibrated,
- "cameras": camera_status,
- "task_status": self._task_state.status.value,
- "current_instruction": self._task_state.instruction,
- "task_duration": self._task_state.duration,
- "task_steps": self._task_state.step_count,
- }
-
- # Add error info if present
- if self._task_state.error_message:
- status_data["task_error"] = self._task_state.error_message
-
- return status_data
-
- except Exception as e:
- logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}")
- return {
- "robot_name": self.tool_name_str,
- "error": str(e),
- "is_connected": False,
- "task_status": "error",
- }
-
- async def stop(self):
- """Stop robot and disconnect."""
- try:
- # Stop any running task first
- if self._task_state.status == TaskStatus.RUNNING:
- self.stop_task()
-
- # Disconnect robot hardware
- if hasattr(self.robot, "disconnect"):
- await asyncio.to_thread(self.robot.disconnect)
-
- # Cleanup resources
- self.cleanup()
+ else:
+ raise ValueError(f"Invalid mode {mode!r}. Choose 'sim', 'real', or 'auto'.")
- logger.info(f"🛑 {self.tool_name_str} stopped and disconnected")
- except Exception as e:
- logger.error(f"❌ Error stopping robot: {e}")
+__all__ = ["Robot"]
diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py
index d9674a9..4aea6f8 100644
--- a/strands_robots/simulation/__init__.py
+++ b/strands_robots/simulation/__init__.py
@@ -7,9 +7,19 @@
├── base.py ← SimEngine ABC
├── factory.py ← create_simulation() + backend registration
├── models.py ← shared dataclasses (SimWorld, SimRobot, ...)
- └── model_registry.py ← URDF/MJCF resolution (shared across backends)
-
- # MuJoCo backend added in subsequent PRs.
+ ├── model_registry.py ← URDF/MJCF resolution (shared across backends)
+ └── mujoco/ ← MuJoCo CPU backend
+ ├── __init__.py
+ ├── backend.py ← lazy mujoco import + GL config
+ ├── mjcf_builder.py ← MJCF XML builder
+ ├── physics.py ← advanced physics (raycasting, jacobians, forces)
+ ├── scene_ops.py ← XML round-trip inject/eject
+ ├── rendering.py ← render RGB/depth, observations
+ ├── policy_runner.py ← run_policy, eval_policy, replay
+ ├── randomization.py ← domain randomization
+ ├── recording.py ← LeRobotDataset recording
+ ├── tool_spec.json ← AgentTool input schema
+ └── simulation.py ← Simulation (AgentTool orchestrator)
Usage::
@@ -62,10 +72,15 @@
TrajectoryStep,
)
-# --- Heavy imports (lazy — loaded when mujoco backend is available) ---
-# MuJoCo-specific lazy imports will be added when the mujoco/ subpackage
-# is introduced. For now, only the lightweight foundation is available.
-_LAZY_IMPORTS: dict[str, tuple[str, str]] = {}
+# --- Heavy imports (lazy — need strands SDK + mujoco) ---
+_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
+ "Simulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"),
+ "MuJoCoSimulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"),
+ "MJCFBuilder": ("strands_robots.simulation.mujoco.mjcf_builder", "MJCFBuilder"),
+ "_configure_gl_backend": ("strands_robots.simulation.mujoco.backend", "_configure_gl_backend"),
+ "_ensure_mujoco": ("strands_robots.simulation.mujoco.backend", "_ensure_mujoco"),
+ "_is_headless": ("strands_robots.simulation.mujoco.backend", "_is_headless"),
+}
__all__ = [
@@ -75,9 +90,9 @@
"create_simulation",
"list_backends",
"register_backend",
- # Default backend alias (available when mujoco backend is installed)
- # "Simulation",
- # "MuJoCoSimulation",
+ # Default backend alias
+ "Simulation",
+ "MuJoCoSimulation",
# Shared dataclasses
"SimStatus",
"SimRobot",
@@ -85,8 +100,8 @@
"SimCamera",
"SimWorld",
"TrajectoryStep",
- # MuJoCo builder (available when mujoco backend is installed)
- # "MJCFBuilder",
+ # MuJoCo builder
+ "MJCFBuilder",
# Model registry
"register_urdf",
"resolve_model",
diff --git a/strands_robots/simulation/mujoco/__init__.py b/strands_robots/simulation/mujoco/__init__.py
new file mode 100644
index 0000000..03c6a03
--- /dev/null
+++ b/strands_robots/simulation/mujoco/__init__.py
@@ -0,0 +1,32 @@
+"""MuJoCo simulation backend for strands-robots.
+
+CPU-based physics with offscreen rendering. No GPU required.
+Supports URDF/MJCF loading, multi-robot scenes, policy execution,
+domain randomization, and LeRobotDataset recording.
+
+Usage::
+
+ from strands_robots.simulation.mujoco import MuJoCoSimulation
+
+ sim = MuJoCoSimulation(tool_name="my_sim")
+ sim.create_world()
+ sim.add_robot("so100", data_config="so100")
+ sim.run_policy("so100", policy_provider="mock", instruction="wave")
+
+Or via the top-level alias::
+
+ from strands_robots.simulation import Simulation # → MuJoCoSimulation
+"""
+
+__all__ = [
+ "MuJoCoSimulation",
+]
+
+
+def __getattr__(name: str) -> "type":
+ if name == "MuJoCoSimulation":
+ from strands_robots.simulation.mujoco.simulation import Simulation as _Sim
+
+ globals()["MuJoCoSimulation"] = _Sim
+ return _Sim
+ raise AttributeError(f"module 'strands_robots.simulation.mujoco' has no attribute {name!r}")
diff --git a/strands_robots/simulation/mujoco/backend.py b/strands_robots/simulation/mujoco/backend.py
new file mode 100644
index 0000000..9c0873d
--- /dev/null
+++ b/strands_robots/simulation/mujoco/backend.py
@@ -0,0 +1,138 @@
+"""MuJoCo lazy import and GL backend configuration."""
+
+import ctypes
+import logging
+import os
+import sys
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+_mujoco = None
+_mujoco_viewer = None
+
+
+def _is_headless() -> bool:
+ """Detect if running in a headless environment (no display server).
+
+ Returns True on Linux when no DISPLAY or WAYLAND_DISPLAY is set,
+ which means GLFW-based rendering will fail.
+
+ Windows and macOS are always False because MuJoCo uses native
+ windowing backends (WGL on Windows, CGL on macOS) that support
+ offscreen rendering without X11/Wayland. The EGL/OSMesa fallback
+ is Linux-specific.
+ """
+ if sys.platform != "linux":
+ return False
+ if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
+ return False
+ return True
+
+
+def _configure_gl_backend() -> None: # noqa: C901
+ """Auto-configure MuJoCo's OpenGL backend for headless environments.
+
+ MuJoCo reads MUJOCO_GL at import time to select the OpenGL backend:
+ - "egl" → EGL (GPU-accelerated offscreen, requires libEGL + NVIDIA driver)
+ - "osmesa" → OSMesa (CPU software rendering, slower but always works)
+ - "glfw" → GLFW (default, requires X11/Wayland display server)
+
+ This function MUST be called before `import mujoco`. Setting MUJOCO_GL
+ after import has no effect — the backend is locked at import time.
+
+ Never overrides a user-set MUJOCO_GL value.
+ """
+ if os.environ.get("MUJOCO_GL"):
+ logger.debug(f"MUJOCO_GL already set to '{os.environ['MUJOCO_GL']}', respecting user config")
+ return
+
+ if not _is_headless():
+ return
+
+ # Headless Linux — probe for EGL first (GPU-accelerated), then fall back to OSMesa (CPU)
+ try:
+ ctypes.cdll.LoadLibrary("libEGL.so.1")
+ os.environ["MUJOCO_GL"] = "egl"
+ logger.info("Headless environment detected — using MUJOCO_GL=egl (GPU-accelerated offscreen)")
+ return
+ except OSError:
+ pass
+
+ try:
+ ctypes.cdll.LoadLibrary("libOSMesa.so")
+ os.environ["MUJOCO_GL"] = "osmesa"
+ logger.info("Headless environment detected — using MUJOCO_GL=osmesa (CPU software rendering)")
+ return
+ except OSError:
+ pass
+
+ logger.warning(
+ "Headless environment detected but neither EGL nor OSMesa found. "
+ "MuJoCo rendering will likely fail. Install one of:\n"
+ " GPU: apt-get install libegl1-mesa-dev (or NVIDIA driver provides libEGL)\n"
+ " CPU: apt-get install libosmesa6-dev\n"
+ "Then set: export MUJOCO_GL=egl (or osmesa)"
+ )
+
+
+def _ensure_mujoco() -> "Any":
+ """Lazy import MuJoCo to avoid hard dependency.
+
+ Auto-configures the OpenGL backend for headless environments before
+ importing mujoco, since MUJOCO_GL must be set at import time.
+
+ Uses require_optional() for consistent dependency management across
+ the strands-robots package.
+ """
+ global _mujoco, _mujoco_viewer
+ if _mujoco is None:
+ _configure_gl_backend()
+ from strands_robots.utils import require_optional
+
+ _mujoco = require_optional(
+ "mujoco",
+ pip_install="mujoco",
+ extra="sim-mujoco",
+ purpose="MuJoCo simulation",
+ )
+ if _mujoco_viewer is None and not _is_headless():
+ try:
+ import mujoco.viewer as viewer
+
+ _mujoco_viewer = viewer
+ except ImportError:
+ pass
+ return _mujoco
+
+
+_rendering_available: bool | None = None
+
+
+def _can_render() -> bool:
+ """Check if MuJoCo offscreen rendering is available.
+
+ Probes once by creating a minimal Renderer. Result is cached.
+ Returns False on headless environments without EGL/OSMesa.
+ """
+ global _rendering_available
+ if _rendering_available is not None:
+ return _rendering_available
+
+ mj = _ensure_mujoco()
+ try:
+ model = mj.MjModel.from_xml_string("