diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 79b15d9..f389ad9 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -26,6 +26,11 @@ jobs: python-version: '3.12' cache: 'pip' + - name: Install system dependencies (OpenGL for MuJoCo) + run: | + sudo apt-get update + sudo apt-get install -y libosmesa6-dev ffmpeg + - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -35,4 +40,6 @@ jobs: run: hatch run lint - name: Run tests + env: + MUJOCO_GL: osmesa run: hatch run test -x --strict-markers diff --git a/.gitignore b/.gitignore index 2e430c6..28eab63 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,7 @@ dist .strands_robots .coverage .ideation/ +MUJOCO_LOG.TXT +TASKS.md +TASKS_TO_FIX_85.md +.coverage.* diff --git a/AGENTS.md b/AGENTS.md index 6c1ad5a..59275a6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,4 +1,4 @@ -# AGENTS.md — strands-labs/robots +# AGENTS.md - strands-labs/robots ## Overview @@ -11,7 +11,7 @@ > **RULE**: ALWAYS use the project board to track work. When creating follow-up items, > create GitHub issues and add them to this board with Status + Priority set. -> Never track work only in local markdown — the board is the source of truth. +> Never track work only in local markdown - the board is the source of truth. ## Repository Structure @@ -61,20 +61,20 @@ hatch run format # ruff check --fix, ruff format ``` > **Note**: Hatch uses `uv` as installer (`installer = "uv"` in pyproject.toml) for faster -> environment creation. No manual uv install needed — hatch handles it. +> environment creation. No manual uv install needed - hatch handles it. ## Key Conventions -1. **Python 3.12+** — `requires-python = ">=3.12"` (LeRobot >=0.5.0 requires 3.12) -2. **Dependency bounds** — `>=1.0` deps: cap major. `<1.0` deps: cap minor. E.g. `lerobot>=0.5.0,<0.6.0` -3. **`__init__.py` must be thin** — exports only, no logic -4. **Imports at file top** — unless lazy-loading heavy deps with documented reason -5. **Raise on fatal errors** — never warn-and-continue if the system will behave unexpectedly -6. **No silent defaults on error** — returning zero-valued actions on failure is forbidden -7. **Use `require_optional()`** — from `strands_robots/utils.py` for all optional deps -8. **Integration tests required** — each policy needs `tests_integ/` tests with real inference -9. **Test behavior, not implementation** — assert on outputs, not internal state -10. **No dead code** — if it's not called and not part of base class, delete it +1. **Python 3.12+** - `requires-python = ">=3.12"` (LeRobot >=0.5.0 requires 3.12) +2. **Dependency bounds** - `>=1.0` deps: cap major. `<1.0` deps: cap minor. E.g. `lerobot>=0.5.0,<0.6.0` +3. **`__init__.py` must be thin** - exports only, no logic +4. **Imports at file top** - unless lazy-loading heavy deps with documented reason +5. **Raise on fatal errors** - never warn-and-continue if the system will behave unexpectedly +6. **No silent defaults on error** - returning zero-valued actions on failure is forbidden +7. **Use `require_optional()`** - from `strands_robots/utils.py` for all optional deps +8. **Integration tests required** - each policy needs `tests_integ/` tests with real inference +9. **Test behavior, not implementation** - assert on outputs, not internal state +10. **No dead code** - if it's not called and not part of base class, delete it ## PR Workflow @@ -94,7 +94,7 @@ hatch run format # ruff check --fix, ruff format `asimovinc/asimov-v0` which has `sim-model/xmls/asimov.xml` + `sim-model/assets/`. The `_safe_join` helper in `strands_robots/utils.py` guards against traversal (`..`). -- **Auto-download strategy** — every robot with an `asset` block must declare +- **Auto-download strategy** - every robot with an `asset` block must declare exactly one of: 1. `asset.robot_descriptions_module` (preferred) 2. `asset.source` with `type: "github"` diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..a4b6dde --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,217 @@ +# CHANGELOG + +All notable behavioural changes to `strands-robots` are logged here. Follows +[Keep a Changelog](https://keepachangelog.com/) conventions. + +## Unreleased - PR #85 (MuJoCo backend remediation) + +### Breaking + +These changes tighten the MuJoCo AgentTool contract. Legacy callers that +silently worked by accident will now receive a clear error instead: + +- **Router input validation**: The ``_dispatch_action`` router rejects any + top-level parameter that isn't declared on the target method. Passing + ``step(num_steps=5)`` (wrong name) or ``set_gravity(device="mps")`` + (stray kwarg) now errors with *"Unknown parameter X for action Y. + Valid: [...]"* instead of silently dropping the value. Methods whose + Python signature includes ``**kwargs`` (e.g. ``add_object``) keep their + pass-through semantics. +- **Missing required args**: produce *"Action X requires parameter Y."* + instead of a raw Python ``TypeError``. +- **Vector dimension validation**: ``position``, ``target``, ``origin``, + ``force``, ``torque``, ``gravity``, ``direction``, ``point``, ``orientation`` + (quaternion), and ``color`` (rgba) all validated for length + numeric + dtype before reaching numpy/MuJoCo. +- **Camera orientation**: ``add_camera(target=[x,y,z])`` is now honoured + by baking ``xyaxes`` into the MJCF ````. Previously the target + was silently dropped and every custom camera rendered a default view. + Degenerate case (``target == position``) errors. +- **Render camera validation**: ``render(camera_name="missing")`` errors + with *"Camera 'missing' not found."* instead of silently falling back + to the free camera while claiming to render from the named one. +- **Raycast zero-direction guard**: ``raycast(direction=[0,0,0])`` now + errors with *"direction vector is zero-length"*. Previously MuJoCo's + C-level ``mj_ray`` would abort the Python process. +- **apply_force requires a non-zero vector**: passing neither ``force`` + nor ``torque`` (or both zero) errors. Previously the call silently + succeeded with no effect. +- **step(n_steps<0)** rejected (previously it corrupted ``step_count``). +- **Negative mass / timestep / size** rejected per shape; previously + ``set_body_properties(mass=-1)`` and ``set_timestep(-0.01)`` silently + succeeded. +- **Plane objects auto-static**: ``add_object(shape="plane")`` now forces + ``is_static=True`` (planes are infinite in MuJoCo). Explicit + ``is_static=False`` on a plane is a hard error. +- **Duplicate camera name** rejected. Previously a second ``add_camera`` + with an existing name silently overwrote the registry entry while + leaving the old camera in the XML - ghost behaviour. Use + ``remove_camera`` + ``add_camera`` to replace. +- **stop_policy(robot_name='')** errors with *"stop_policy requires + 'robot_name'."* instead of silently matching the first robot. +- **eval_policy** requires an explicit ``robot_name``. Default + ``n_episodes`` lowered from 10 to 1. +- **register_urdf** validates the path: file must exist, be a file, and + be readable. Previously bad paths were cached and blew up later. + +### Recording backend split + +- ``start_recording`` (LeRobotDataset: parquet + per-camera MP4) still + requires the ``[lerobot]`` extra. Its error message when lerobot is + missing now points callers at ``start_cameras_recording`` for plain + MP4 (which runs under ``[sim-mujoco]`` alone via imageio-ffmpeg). +- No API change - the fix is informational. + +### Resource hygiene + +- ``destroy()`` and ``cleanup()`` now close renderers on the main thread + and empty the TLS cache. Previously each ``create_world/destroy`` + cycle leaked one ``mujoco.Renderer`` + its GL context (~33 MB per + cycle measured). Worker-thread renderers still release themselves on + thread teardown (we avoid cross-thread ``close()`` to prevent + ``cgl.free()`` SIGSEGVs on macOS). +- ``get_mass_matrix`` and ``get_contacts`` run ``mj_forward`` first so + values are valid immediately after a ``reset`` or ``add_robot`` + (previously returned stale / uninitialised memory). + +### Concurrency guards + +Write-mutations are now refused while a policy is running on any robot +in the world. Previously these could race the policy worker thread and +produce undefined behaviour or SIGSEGV: + + reset, set_gravity, set_timestep, set_joint_positions, + set_joint_velocities, apply_force, set_body_properties, + set_geom_properties, load_state, randomize, move_object + +The error now lists *which* robot(s) are active so the LLM can +``stop_policy`` on each without guessing: *"Cannot 'X' while a policy +is running on 'armA', 'armB'. Stop it first: action='stop_policy'."* + +### Concurrent per-robot policies (GH #114) + +Multiple ``start_policy`` calls on *different* robots now run +concurrently. MuJoCo physics is still serialized via ``self._lock`` +(``mj_step`` and ``ctrl[]`` writes are not thread-safe for concurrent +mutation), but each policy owns a disjoint slice of ``data.ctrl[]`` so +two VLA arms can operate in the same scene without semantic conflict. + +- ``start_policy("armA")`` + ``start_policy("armB")`` both succeed. + Second call no longer hits a global "policy already running" gate. +- ``start_policy`` on the *same* robot while its policy is active + still errors (unchanged). +- ``remove_robot("X")`` now gracefully stops X's own policy before + removing, instead of requiring a prior ``stop_policy("X")``. Still + errors if a *different* robot has an active policy (XML round-trip + invalidates cached IDs everywhere). +- New action ``list_policies_running`` returns the names of robots + with live policies. Prunes completed Futures as a side-effect. +- Completed policy Futures are no longer retained forever in + ``_policy_threads`` (GH #120 companion fix). + +### Policy-hook robustness (GH #117) + +``PolicyRunner.run`` previously caught *all* ``on_frame`` exceptions at +WARN level and kept iterating. A recording hook with a typo'd observation +key would log 500 lines and produce an empty dataset. Now we count +*consecutive* failures and abort the episode after a threshold (default +5, tunable via new ``max_onframe_failures`` kwarg). + +- A single transient failure still logs + continues; counter resets on + the next successful call. +- ``N`` consecutive failures raise ``RuntimeError`` so ``run()`` returns + ``status='error'`` with a clear message, preventing silent dataset + corruption. + +### Cleanup graceful shutdown (GH #116) + +``Simulation.cleanup()`` no longer races the policy worker. Previously +cleanup set ``self._world = None`` and called ``executor.shutdown(wait=False)`` +nearly simultaneously - a policy still inside ``mj_step`` segfaulted on +freed arrays. Now cleanup: + +1. Signals every live policy to stop (``policy_running = False``). +2. Awaits each outstanding Future with a bounded timeout (default 5s, + overridable via new ``cleanup(policy_stop_timeout=...)`` kwarg). +3. Only AFTER workers unwind do we null ``self._world`` and tear down + renderers / viewer / executor. + +Wedged workers that don't stop in time get logged as a warning - cleanup +proceeds rather than hanging the host process on exit. + +### Error message consistency + +- All "no world" paths return the same string: + *"No world. Call create_world (or load_scene) first."* +- Unknown-name errors use a uniform `` 'X' not found.`` shape + (Robot / Object / Body / Geom / Joint / Sensor / Camera / Checkpoint). +- ``stop_recording``, ``stop_cameras_recording``, ``stop_policy``, + ``close_viewer`` are now **idempotent**: calling them when nothing + is running returns ``status="success"`` with a *"Was not ..."* message + so callers can invoke them unconditionally. +- ``get_recording_status`` returns success in every lifecycle state + (no world / not recording / recording). + +### Deprecations + +- **add_robot name-as-registry fallback**: passing ``name="my_bot"`` + without ``urdf_path`` or ``data_config`` used to resolve ``my_bot`` in + the model registry. This now fires a ``DeprecationWarning``. Use + ``add_robot(name="...", data_config="")`` instead. Will + be removed next major release. + +### New / extended actions + +- ``forward_kinematics(body_name="X")`` filters to a single body. +- ``get_features(robot_name="X")`` filters to a single robot's joints + and actuators. +- ``set_geom_properties(geom_name="X")`` accepts the bare object name + as an alias for the injected ``"{name}_geom"``. +- ``render_all`` flags cameras whose frame has near-zero pixel variance + (``"⚠️ camera 'X': image appears empty (variance < 1)"``). +- ``render_depth`` surfaces MuJoCo's one-time ``ARB_clip_control`` + warning in the response text on macOS, so the LLM knows when depth + accuracy is reduced. +- ``render`` / ``render_depth``: width/height validated up front; + oversized requests get a plain-English message naming the actual + framebuffer cap (````) instead of MuJoCo's raw + error. +- ``run_policy`` / ``start_policy``: accept optional ``n_steps`` + (primary) or legacy ``max_steps`` as an alternative to + ``duration``+``control_frequency``. ``duration = n_steps / + control_frequency`` when ``n_steps`` is set. +- **New ``list_policies_running``** action returns the names of robots + with a live policy - pairs with the new concurrent-policy support + (see *Concurrent per-robot policies* above). +- ``randomize(randomize_physics=True)`` now reports per-body mass scales + and per-geom friction scales in the response (not just range + endpoints). +- ``get_contacts`` resolves unnamed geoms to + ``"/geom_"`` so contact pairs are always human-readable. +- ``get_sensor_data(sensor_name="X")`` on a model with no sensors now + distinguishes *"Sensor 'X' not found. Model has no sensors."* from + the generic "no sensors in model" success. + +### Tests + +- New: ``tests/simulation/mujoco/test_agenttool_contract.py`` - ~50 + tests that lock in router validation, tool_spec ↔ method parity, + unified error messages, idempotent stop family, ``mj_forward`` before + reads, render-dim validation, feature filters, camera duplicate + policy, plane auto-static, policy horizon unification, and more. +- New: ``tests/simulation/mujoco/test_renderer_hygiene.py`` - 4 tests + asserting TLS cache is emptied on ``destroy``, renderer reuse works + for identical ``(w,h)``, and ``create_world`` after ``destroy`` + rebuilds cleanly. +- New: ``tests/simulation/mujoco/test_recording_backends.py`` - 2 tests + (one skipped when ``lerobot`` IS installed) pinning the + MP4-without-lerobot backend. +- New: ``tests/simulation/mujoco/test_input_validation.py`` - 11 tests + for step/raycast/apply_force validation. +- New: ``tests_integ/test_resource_hygiene.py`` - 3 integration tests + (require ``psutil``): 50 create/destroy cycles grow RSS < 50 MB; 500 + renders at fixed dims grow RSS < 100 MB; TLS cache cleared on destroy. + +Test count: **256 → 362** (+106 new regression tests), zero +regressions. ``hatch run lint`` (ruff + mypy) clean across 102 source +files. diff --git a/README.md b/README.md index 0a93a93..7e4591d 100644 --- a/README.md +++ b/README.md @@ -493,7 +493,7 @@ agent.tool.gr00t_inference(action="stop", port=8000) | Variable | Description | Default | |----------|-------------|---------| | `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (MJCF, meshes) | `~/.strands_robots/assets/` | -| `GROOT_API_TOKEN` | API token for GR00T inference service | — | +| `GROOT_API_TOKEN` | API token for GR00T inference service | - | ### Cache Directory @@ -511,6 +511,108 @@ To clear the cache: `rm -rf ~/.strands_robots/assets/` To change the cache location: `export STRANDS_ASSETS_DIR=/path/to/custom/dir` +## Simulation (MuJoCo) + +`strands-robots` ships a MuJoCo-backed simulation AgentTool - 58 actions +exposed to any Strands agent for world composition, physics, policy +execution, and video/dataset recording. + +### Install + +```bash +pip install "strands-robots[sim-mujoco]" +# For LeRobotDataset recording (parquet + training data): +pip install "strands-robots[sim-mujoco,lerobot]" +``` + +### Quick start + +```python +from strands_robots.simulation import Simulation + +sim = Simulation(tool_name="sim", mesh=False) +sim.create_world() +sim.add_robot(name="arm", data_config="so100") +sim.add_object(name="cube", shape="box", position=[0.3, 0, 0.05]) +sim.add_camera(name="topdown", position=[0, 0, 1.5], target=[0, 0, 0]) + +sim.run_policy(robot_name="arm", policy_provider="mock", n_steps=200, + control_frequency=50.0, fast_mode=True) + +frame = sim.render(camera_name="topdown") # returns {status, content:[text, image]} +``` + +### 58 actions grouped + +- **World & objects**: `create_world`, `load_scene`, `add_robot`, + `add_object`, `move_object`, `list_objects`, `list_robots`, + `remove_robot`, `remove_object`, `destroy`, `reset`, `get_state`, + `save_state`, `load_state`, `list_checkpoints`. +- **Physics**: `step`, `set_timestep`, `set_gravity`, `apply_force`, + `raycast`, `multi_raycast`, `set_body_properties`, + `set_geom_properties`, `get_body_state`, `get_joint_state`, + `set_joint_positions`, `set_joint_velocities`, `forward_kinematics`, + `get_mass_matrix`, `inverse_dynamics`, `get_total_mass`, + `get_jacobian`, `get_energy`, `get_contacts`, `get_sensor_data`. +- **Cameras & rendering**: `add_camera`, `remove_camera`, `render`, + `render_depth`, `render_all`, `start_cameras_recording`, + `stop_cameras_recording`, `get_cameras_recording_status`. +- **Policy**: `start_policy`, `run_policy`, `stop_policy`, + `replay_episode`, `eval_policy`. +- **Randomization**: `randomize`. +- **Recording (LeRobotDataset)**: `start_recording`, `stop_recording`, + `get_recording_status`. +- **Introspection & util**: `get_features`, `list_urdfs`, `register_urdf`, + `export_xml`, `open_viewer`, `close_viewer`. + +### Common footguns + +- **Planes must be static.** `add_object(shape="plane")` auto-sets + `is_static=True`. Passing `is_static=False` on a plane is a hard error + (MuJoCo planes are infinite and can't have dynamic mass). +- **Camera orientation.** Pass `target=[x,y,z]` to look at a point - + without it the camera faces forward by default. `target == position` + errors. +- **MP4 vs dataset recording.** `start_cameras_recording` writes plain + MP4 per-camera and runs under `[sim-mujoco]` alone. `start_recording` + writes a LeRobotDataset (parquet + MP4 + schema) and requires the + `[lerobot]` extra. +- **Policy running → mutations blocked.** While a policy runs on any + robot, state-mutating actions (`reset`, `set_gravity`, joint setters, + `apply_force`, `set_body_properties`, `set_geom_properties`, + `load_state`, `randomize`, `move_object`) error with *"Cannot 'X' + while a policy is running."* Stop it first with + `stop_policy(robot_name='...')`. +- **Horizon parameters.** `run_policy` accepts either `duration` + + `control_frequency` (real-time) OR `n_steps` + `control_frequency` + (step-count). Pass `fast_mode=True` to skip the between-step sleep + during batch eval / data collection. +- **Name collisions.** Objects, bodies, robots, and cameras share the + MuJoCo name table. Robot joints and actuators are auto-namespaced as + `{robot_name}/{joint}` in multi-robot scenes. Object geoms are + injected as `{object_name}_geom`; `set_geom_properties` accepts the + bare object name as an alias. +- **Oversized render**: MuJoCo's offscreen framebuffer is capped by + `` in MJCF. Requesting a bigger + render now errors with a plain message naming the cap - either lower + the request or rebuild the model with larger dims. + +### Self-healing features + +- Unknown parameters are rejected with *"Unknown parameter X for action + Y. Valid: [...]"* so the LLM learns the correct name without trial- + and-error. +- Missing required parameters produce *"Action X requires parameter Y."* + (no Python `TypeError` leaks). +- Vector dimensions and numeric dtype are validated before MuJoCo sees + them (previously zero-length direction vectors crashed the Python + process via `mj_ray` C-level abort). +- `destroy()` and `cleanup()` empty the renderer TLS cache and shut down + the executor - no RSS growth across repeated create/destroy cycles. + +For the full action contract and test coverage see +`tests/simulation/mujoco/test_agenttool_contract.py`. + ## Contributing We welcome contributions! Please see: diff --git a/pyproject.toml b/pyproject.toml index f1a7090..3a4ce60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,13 +48,16 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] -sim = [ +sim-mujoco = [ "robot_descriptions>=1.11.0,<2.0.0", + "mujoco>=3.0.0,<4.0.0", + "imageio>=2.28.0,<3.0.0", + "imageio-ffmpeg>=0.4.0,<1.0.0", ] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", - "strands-robots[sim]", + "strands-robots[sim-mujoco]", ] dev = [ "pytest>=6.0,<9.0.0", @@ -79,6 +82,7 @@ packages = ["strands_robots"] [tool.hatch.envs.default] installer = "uv" +features = ["all"] dependencies = [ "pytest>=6.0,<9.0.0", "pytest-cov>=4.0.0,<6.0.0", @@ -128,7 +132,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check @@ -161,6 +165,22 @@ module = ["strands_robots.registry.*"] warn_return_any = false disallow_untyped_defs = false +# MuJoCo simulation — mixins use cooperative self._world patterns +# attr-defined: Mixins access self._world/self._lock/etc. from Simulation (cooperative pattern) +# assignment: PEP 484 implicit Optional (= None on typed params) +# override: Subclass signatures extend base with extra params (orientation, mesh_path) +# misc: Multiple inheritance method resolution conflicts between mixin + ABC +[[tool.mypy.overrides]] +module = ["strands_robots.simulation.mujoco.*"] +disallow_untyped_defs = false +warn_return_any = false + +# Async utils and dataset recorder — thin wrappers with dynamic types +[[tool.mypy.overrides]] +module = ["strands_robots._async_utils", "strands_robots.dataset_recorder"] +disallow_untyped_defs = false +warn_return_any = false + # Test files — relaxed type checking for mocks, fixtures, and test utilities [[tool.mypy.overrides]] module = ["tests.*", "tests_integ.*"] diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index 8ee9c41..7f5a54f 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -25,14 +25,11 @@ import warnings as _warnings from typing import Any -# ------------------------------------------------------------------ -# Light-weight imports — no torch / lerobot dependency -# ------------------------------------------------------------------ +# Light-weight imports - no torch / lerobot dependency from strands_robots.policies import MockPolicy, Policy, create_policy # noqa: F401 -# ------------------------------------------------------------------ # Lazy-loaded heavy symbols -# ------------------------------------------------------------------ + # Maps public name -> (module_path, attribute_name) _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "Robot": ("strands_robots.robot", "Robot"), diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py new file mode 100644 index 0000000..478518b --- /dev/null +++ b/strands_robots/_async_utils.py @@ -0,0 +1,31 @@ +"""Async-to-sync helper for resolving coroutines in sync contexts.""" + +import asyncio +import concurrent.futures + +# Module-level executor reused across calls to avoid creating threads at high frequency. +# A single worker is sufficient - we only need to offload one asyncio.run() at a time. +_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async") + + +def _resolve_coroutine(coro_or_result): # type: ignore[no-untyped-def] + """Safely resolve a potentially-async result to a sync value. + + Handles three cases: + 1. Already a plain value → return as-is + 2. Coroutine, no running loop → asyncio.run() + 3. Coroutine, inside running loop → offload to reused thread + + Args: + coro_or_result: Either a coroutine or an already-resolved value. + + Returns: + The resolved (sync) value. + """ + if not asyncio.iscoroutine(coro_or_result): + return coro_or_result + try: + asyncio.get_running_loop() + return _EXECUTOR.submit(asyncio.run, coro_or_result).result() + except RuntimeError: + return asyncio.run(coro_or_result) diff --git a/strands_robots/assets/__init__.py b/strands_robots/assets/__init__.py index 4e9080c..a6faea5 100644 --- a/strands_robots/assets/__init__.py +++ b/strands_robots/assets/__init__.py @@ -4,7 +4,7 @@ MuJoCo Menagerie GitHub, cached in ``~/.strands_robots/assets/``. Override with ``STRANDS_ASSETS_DIR`` env var. -Implementation lives in ``assets/manager.py`` — this file is thin exports only. +Implementation lives in ``assets/manager.py`` - this file is thin exports only. """ from strands_robots.assets.manager import ( diff --git a/strands_robots/assets/download.py b/strands_robots/assets/download.py index a6ffa7d..93612d7 100644 --- a/strands_robots/assets/download.py +++ b/strands_robots/assets/download.py @@ -5,7 +5,7 @@ that delegates to :func:`download_robots` here. Strategy (in order of preference): - 1. ``robot_descriptions`` package — recommended by MuJoCo Menagerie. + 1. ``robot_descriptions`` package - recommended by MuJoCo Menagerie. 2. Shallow ``git clone`` fallback for Menagerie robots. 3. Custom GitHub repos for non-Menagerie robots. @@ -40,7 +40,7 @@ _ALLOWED_CLONE_URL_RE = re.compile(r"^https://github\.com/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+\.git$") -# ── robot_descriptions integration ──────────────────────────────────── +# robot_descriptions integration def _robot_descriptions_available() -> bool: @@ -102,10 +102,7 @@ def _resolve_robot_descriptions_module(name: str, info: dict) -> str | None: return None -# ── Helpers ─────────────────────────────────────────────────────────── - - -#: Alias for backward compatibility — use :func:`strands_robots.utils.get_assets_dir`. +#: Alias for backward compatibility - use :func:`strands_robots.utils.get_assets_dir`. get_user_assets_dir = get_assets_dir @@ -151,7 +148,7 @@ def _get_source(info: dict[str, Any] | None) -> dict[str, Any]: def _shallow_clone(repo_url: str, dest: str, *, timeout: int = 120) -> None: """Shallow-clone *repo_url* into *dest*. - Only HTTPS ``github.com`` URLs are accepted — ``ssh://``, ``git://``, + Only HTTPS ``github.com`` URLs are accepted - ``ssh://``, ``git://``, ``file://``, and other schemes are rejected to prevent command-injection and SSRF risks. @@ -174,7 +171,7 @@ def _shallow_clone(repo_url: str, dest: str, *, timeout: int = 120) -> None: # Filenames/patterns that are safe to strip from an upstream source tree before # we copy it into the user's asset cache. Filtering at *copy* time (rather than # deleting afterwards) means we never touch files that may already exist in *dst* -# — which matters when the user keeps notes/README alongside assets. +# - which matters when the user keeps notes/README alongside assets. _COPY_CLEAN_SKIP = frozenset({"README.md", "LICENSE", "CHANGELOG.md"}) _COPY_CLEAN_SUFFIX = (".png", ".jpg", ".jpeg") @@ -195,9 +192,6 @@ def _ignore(_dir: str, names: list[str]) -> list[str]: shutil.copytree(str(src), str(dst), dirs_exist_ok=True, ignore=_ignore) -# ── Download backends ───────────────────────────────────────────────── - - def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> dict[str, str]: """Download robots using the ``robot_descriptions`` package. @@ -234,9 +228,9 @@ def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> if expected_xml.exists(): results[name] = "downloaded" continue - # Stale symlink — remove and re-download via git + # Stale symlink - remove and re-download via git dst.unlink() - results[name] = f"failed: stale symlink — {info['asset']['model_xml']} not found in {package_path}" + results[name] = f"failed: stale symlink - {info['asset']['model_xml']} not found in {package_path}" continue if dst.exists() or dst.is_symlink(): dst.unlink() if dst.is_symlink() else shutil.rmtree(str(dst)) @@ -251,7 +245,7 @@ def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> if not expected_xml.exists(): logger.warning( "robot_descriptions module '%s' linked for %s but " - "expected XML '%s' not found — falling back to git", + "expected XML '%s' not found - falling back to git", module_name, name, info["asset"]["model_xml"], @@ -261,7 +255,7 @@ def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> else: shutil.rmtree(str(dst), ignore_errors=True) results[name] = ( - f"failed: XML mismatch — module '{module_name}' does not contain {info['asset']['model_xml']}" + f"failed: XML mismatch - module '{module_name}' does not contain {info['asset']['model_xml']}" ) continue @@ -333,7 +327,7 @@ def _download_from_github(name: str, info: dict, dest_dir: Path) -> str: return f"failed: {exc}" -# ── Orchestrator ────────────────────────────────────────────────────── +# Orchestrator def auto_download_robot(name: str, info: dict[str, Any]) -> bool: @@ -379,20 +373,20 @@ def download_robots( """Download robot model assets from their respective sources. Strategy (in order of preference): - 1. ``robot_descriptions`` package — recommended by MuJoCo Menagerie. + 1. ``robot_descriptions`` package - recommended by MuJoCo Menagerie. 2. Shallow ``git clone`` fallback for Menagerie robots. 3. Custom GitHub repos for non-Menagerie robots. Args: names: Robot names to download (``None`` = all sim robots). - category: Filter by category (arm, humanoid, mobile, …). + category: Filter by category (arm, humanoid, mobile, ...). force: Re-download even if present. Returns: Dict with downloaded/skipped/failed counts, names, and details. """ dest_dir = get_user_assets_dir() - # Filter None values — get_robot() can return None for unknown names + # Filter None values - get_robot() can return None for unknown names all_sim: dict[str, dict[str, Any]] = { r["name"]: info for r in registry_list_robots(mode="sim") if (info := get_robot(r["name"])) is not None } diff --git a/strands_robots/assets/manager.py b/strands_robots/assets/manager.py index ca610a6..34f37ce 100644 --- a/strands_robots/assets/manager.py +++ b/strands_robots/assets/manager.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -# Module-level conditional import — keeps manager.py importable in +# Module-level conditional import - keeps manager.py importable in # environments where the optional ``robot_descriptions`` package (and its # transitive heavyweight deps like ``GitPython``) are not installed. # When ``download`` is not available, auto-download simply returns False. @@ -32,9 +32,9 @@ _auto_download_robot_impl = None # type: ignore[assignment] -# ───────────────────────────────────────────────────────────────────── +# # Model path resolution (delegates to registry) -# ───────────────────────────────────────────────────────────────────── +# def _auto_download_robot(name: str, info: dict) -> bool: @@ -114,7 +114,7 @@ def _resolve_candidates(asset_dir_name: str, xml_file: str, name: str) -> list[P def is_robot_asset_present(name: str) -> bool: """Check whether a robot's model XML exists on disk without triggering downloads. - Pure filesystem check — no auto-download, no mesh walk, no network. + Pure filesystem check - no auto-download, no mesh walk, no network. Use this for status queries (e.g. ``download_assets(action="status")``) where you need to quickly check presence without side effects. @@ -194,7 +194,7 @@ def resolve_model_path( # Check user-registered asset path first (highest priority). # ``xml_file`` comes from user_robots.json, so we still gate it through # :func:`safe_join` to block path traversal even for user-authored entries - # (defense in depth — protects against a compromised user_robots.json and + # (defense in depth - protects against a compromised user_robots.json and # keeps the trust boundary identical to the built-in registry path). user_path = info.get("_user_asset_path") if user_path: @@ -214,7 +214,7 @@ def resolve_model_path( candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) if not candidates: - # No XML found at all — try auto-download, then re-search + # No XML found at all - try auto-download, then re-search logger.info("No XML found for %s, attempting auto-download...", name) if _auto_download_robot(name, info): candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) @@ -230,7 +230,7 @@ def resolve_model_path( logger.debug("Resolved %s → %s (has meshes)", name, path) return Path(path) - # XML found but no meshes — auto-download and re-check + # XML found but no meshes - auto-download and re-check logger.info("XML found for %s but no meshes, attempting auto-download...", name) if _auto_download_robot(name, info): # Re-scan after download (new symlinks may have appeared) @@ -305,7 +305,7 @@ def list_available_robots() -> list[dict]: name = r["name"] present = is_robot_asset_present(name) info = get_robot(name) or {} - # Only resolve full path when asset is present — avoids download attempts + # Only resolve full path when asset is present - avoids download attempts path = resolve_model_path(name) if present else None robots.append( { diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py new file mode 100644 index 0000000..d38f01b --- /dev/null +++ b/strands_robots/dataset_recorder.py @@ -0,0 +1,515 @@ +"""LeRobotDataset recorder bridge for strands-robots. + +Wraps LeRobotDataset so that both robot.py (real hardware) and +simulation.py (MuJoCo) can produce training-ready datasets with +a single add_frame() call per control step. + +Usage: + recorder = DatasetRecorder.create( + repo_id="user/my_dataset", + fps=30, + robot_features=robot.observation_features, + action_features=robot.action_features, + task="pick up the red cube", + ) + # In control loop: + recorder.add_frame(observation, action, task="pick up the red cube") + # End of episode: + recorder.save_episode() + # Optionally: + recorder.push_to_hub() +""" + +import functools +import logging +import sys +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + +# Lazy check for LeRobot availability +# We must NOT import lerobot at module level because it pulls in +# `datasets` → `pandas`, which can crash with a numpy ABI mismatch on +# systems where the system pandas was compiled against an older numpy +# (e.g. JetPack / Jetson with system pandas 2.1.4 + pip numpy 2.x). + + +@functools.lru_cache(maxsize=1) +def has_lerobot_dataset() -> bool: + """Check if lerobot is available. Result is cached after first call.""" + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: F401 + + return True + except (ImportError, ValueError, RuntimeError) as exc: + logger.debug("lerobot not available: %s", exc) + return False + + +def _get_lerobot_dataset_class(): + """Import and return LeRobotDataset class, or raise ImportError. + + Supports test mocking: if ``strands_robots.dataset_recorder.LeRobotDataset`` + has been set (by a test mock), returns that class directly. + """ + # Support test mocking: check module-level overrides + this_module = sys.modules[__name__] + + # If a test injected a mock LeRobotDataset class, use it + mock_cls = getattr(this_module, "LeRobotDataset", None) + if mock_cls is not None: + return mock_cls + + # Actual import + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + return LeRobotDataset + except (ImportError, ValueError, RuntimeError) as exc: + raise ImportError( + f"lerobot not available ({exc}). Install with: pip install lerobot\nRequired for LeRobotDataset recording." + ) from exc + + +class DatasetRecorder: + """Bridge between strands-robots control loops and LeRobotDataset. + + Handles the full lifecycle: + 1. create() - build LeRobotDataset with correct features + 2. add_frame() - called every control step with obs + action + 3. save_episode() - finalize episode (encodes video, writes parquet) + 4. push_to_hub() - upload to HuggingFace + + Works for both real hardware (robot.py) and simulation (simulation.py). + """ + + def __init__(self, dataset, task: str = "", strict: bool = True): + self.dataset = dataset + self.default_task = task + self.frame_count = 0 + self.dropped_frame_count = 0 + self.strict = strict + self.episode_count = 0 + self._closed = False + self._cached_state_keys: list[str] | None = None + self._cached_action_keys: list[str] | None = None + + @classmethod + def create( + cls, + repo_id: str, + fps: int = 30, + robot_type: str = "unknown", + robot_features: dict[str, Any] | None = None, + action_features: dict[str, Any] | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + task: str = "", + root: str | None = None, + use_videos: bool = True, + vcodec: str = "libsvtav1", + streaming_encoding: bool = True, + image_writer_threads: int = 4, + video_backend: str = "auto", + ) -> "DatasetRecorder": + """Create a new DatasetRecorder with auto-detected features. + + Args: + repo_id: HuggingFace dataset ID (e.g. "user/my_dataset") + fps: Recording frame rate + robot_type: Robot type string (e.g. "so100", "panda") + robot_features: Dict of observation feature names → types + (from robot.observation_features or sim joint names) + action_features: Dict of action feature names → types + camera_keys: List of camera names (images become video features) + joint_names: List of joint names (alternative to robot_features for sim) + task: Default task description + root: Local directory for dataset storage + use_videos: Encode camera frames as video (True) or keep as images + vcodec: Video codec (h264, hevc, libsvtav1) + streaming_encoding: Stream-encode video during capture + image_writer_threads: Threads for writing image frames + video_backend: Video backend for encoding ("auto" for HW encoder auto-detect) + """ + # Lazy import - this is where we actually need lerobot + LeRobotDatasetCls = _get_lerobot_dataset_class() + + # Build features dict in LeRobot format + features = cls._build_features( + robot_features=robot_features, + action_features=action_features, + camera_keys=camera_keys, + joint_names=joint_names, + use_videos=use_videos, + ) + + logger.info(f"Creating LeRobotDataset: {repo_id} @ {fps}fps, {len(features)} features, robot_type={robot_type}") + + # Build kwargs, skip unsupported params for this LeRobot version + create_kwargs = dict( + repo_id=repo_id, + fps=fps, + root=root, + robot_type=robot_type, + features=features, + use_videos=use_videos, + image_writer_threads=image_writer_threads, + vcodec=vcodec, + ) + # streaming_encoding only in newer LeRobot versions + import inspect + + create_sig = inspect.signature(LeRobotDatasetCls.create) + if "streaming_encoding" in create_sig.parameters: + create_kwargs["streaming_encoding"] = streaming_encoding + if "video_backend" in create_sig.parameters: + create_kwargs["video_backend"] = video_backend + dataset = LeRobotDatasetCls.create(**create_kwargs) + + recorder = cls(dataset=dataset, task=task) + logger.info("DatasetRecorder ready: %s", repo_id) + return recorder + + @classmethod + def _build_features( + cls, + robot_features: dict | None = None, + action_features: dict | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + use_videos: bool = True, + ) -> dict[str, Any]: + """Build LeRobot v3-compatible features dict. + + LeRobot v3 features format: + { + "observation.images.camera_name": {"dtype": "video", "shape": (C, H, W), "names": [...]}, + "observation.state": {"dtype": "float32", "shape": (N,), "names": [...]}, + "action": {"dtype": "float32", "shape": (N,), "names": [...]}, + } + + Note: "names" must be a flat list of strings, NOT a dict like {"motors": [...]}. + """ + features = {} + + # Observation: cameras → video/image features + if camera_keys: + for cam_name in camera_keys: + key = f"observation.images.{cam_name}" + dtype = "video" if use_videos else "image" + features[key] = { + "dtype": dtype, + "shape": ( + 3, + 480, + 640, + ), # CHW default, actual shape set on first frame + "names": ["channels", "height", "width"], + } + + # Observation: state (joint positions) + state_dim = 0 + state_names = [] + if robot_features: + # Count scalar features (exclude cameras) + state_keys = [ + k + for k, v in robot_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + state_dim = len(state_keys) + state_names = state_keys + elif joint_names: + state_dim = len(joint_names) + state_names = list(joint_names) + + if state_dim > 0: + features["observation.state"] = { + "dtype": "float32", + "shape": (state_dim,), + "names": state_names, + } + + # Action + action_dim = 0 + action_names = [] + if action_features: + action_keys = [ + k + for k, v in action_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + action_dim = len(action_keys) + action_names = action_keys + elif joint_names: + action_dim = len(joint_names) + action_names = list(joint_names) + elif state_dim > 0: + action_dim = state_dim # Same dim as state by default + action_names = state_names[:] + + if action_dim > 0: + features["action"] = { + "dtype": "float32", + "shape": (action_dim,), + "names": action_names[:action_dim], + } + + return features + + def add_frame( + self, + observation: dict[str, Any], + action: dict[str, Any], + task: str | None = None, + camera_keys: list[str] | None = None, + ) -> None: + """Add a single control-loop frame to the dataset. + + This is the key method - called every step in the control loop. + + Args: + observation: Raw observation dict from robot/sim + (joint_name → float, camera_name → np.ndarray) + action: Action dict (joint_name → float) + task: Task description (uses default if None) + camera_keys: Which keys in observation are camera images + """ + if self._closed: + return + + frame = {} + + # Detect camera vs state keys + if camera_keys is None: + camera_keys = [k for k, v in observation.items() if isinstance(v, np.ndarray) and v.ndim >= 2] + + state_keys = [k for k in observation.keys() if k not in camera_keys] + + # Camera images → observation.images.{name} + for cam_key in camera_keys: + img = observation[cam_key] + if isinstance(img, np.ndarray): + # LeRobot expects HWC uint8 for add_frame + if img.dtype != np.uint8: + img = (np.clip(img, 0, 1) * 255).astype(np.uint8) + frame[f"observation.images.{cam_key}"] = img + + # State → observation.state (flattened vector) + # Use feature schema ordering to match the dataset schema declared in _build_features(). + if state_keys: + state_vals = [] + if self._cached_state_keys is None: + feat = self.dataset.features.get("observation.state", {}) + state_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_state_keys = state_names if state_names else sorted(state_keys) + + for k in self._cached_state_keys: + v = observation.get(k) + if v is None: + state_vals.append(0.0) + elif isinstance(v, (int, float)): + state_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + state_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + state_vals.extend(arr.tolist()) + if state_vals: + frame["observation.state"] = np.array(state_vals, dtype=np.float32) + + # Action → flattened vector + # Use feature schema ordering for actions too. + if action: + action_vals = [] + if self._cached_action_keys is None: + feat = self.dataset.features.get("action", {}) + action_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_action_keys = action_names if action_names else sorted(action.keys()) + + for k in self._cached_action_keys: + v = action.get(k) + if v is None: + action_vals.append(0.0) + elif isinstance(v, (int, float)): + action_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + action_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + action_vals.extend(arr.tolist()) + if action_vals: + frame["action"] = np.array(action_vals, dtype=np.float32) + + # Task (mandatory for LeRobot v3) + frame["task"] = task or self.default_task or "untitled" + + # Reconcile camera keys between frame and feature schema + # Normalize namespaced camera keys (e.g. "arm0/wrist_cam" → "arm0__wrist_cam") + # to match the schema declared in _build_features. MuJoCo uses "/" as a + # namespace separator for multi-robot cameras, but LeRobot feature names + # cannot contain "/" (reserved for nested-feature addressing). + declared_cam_keys = {k for k in self.dataset.features if k.startswith("observation.images.")} + frame_cam_keys = {k for k in list(frame.keys()) if k.startswith("observation.images.")} + for cam_key in frame_cam_keys: + normalized = cam_key.replace("/", "__") + if normalized != cam_key and normalized in declared_cam_keys: + frame[normalized] = frame.pop(cam_key) + + # Strip undeclared cameras (keys present in obs but not registered in + # _build_features). This avoids LeRobot's "Extra features" error. + # Declared-but-missing cameras (e.g. when a render fails) are left alone - + # LeRobot tolerates absent columns and the episode simply won't have that + # camera's data. + frame_cam_keys_final = {k for k in frame if k.startswith("observation.images.")} + for extra in frame_cam_keys_final - declared_cam_keys: + del frame[extra] + + # Add to dataset + try: + self.dataset.add_frame(frame) + self.frame_count += 1 + except Exception as e: + if self.strict: + raise # Fail-fast per AGENTS.md convention #5 + self.dropped_frame_count += 1 + n = self.dropped_frame_count + # Log at 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, then every 1000 + if (n & (n - 1)) == 0 or n % 1000 == 0: + logger.warning( + "add_frame failed (frame %d, dropped %d): %s", + self.frame_count, + self.dropped_frame_count, + e, + ) + + def save_episode(self) -> dict[str, Any]: + """Finalize current episode - writes parquet, encodes video, computes stats. + + LeRobot v3: save_episode() takes no task argument. Tasks are stored + per-frame in the episode buffer via add_frame(). + + Returns: + Dict with episode info + """ + if self._closed: + return {"status": "error", "message": "Recorder closed"} + + try: + self.dataset.save_episode() + self.episode_count += 1 + ep_frames = self.frame_count # Total frames so far + logger.info(f"Episode {self.episode_count} saved: {ep_frames} total frames") + return { + "status": "success", + "episode": self.episode_count, + "total_frames": ep_frames, + } + except Exception as e: + logger.error("save_episode failed: %s", e) + return {"status": "error", "message": str(e)} + + def finalize(self) -> None: + """Finalize the dataset (close parquet writers, flush metadata).""" + if self._closed: + return + try: + self.dataset.finalize() + except Exception as e: + logger.warning("finalize warning: %s", e) + self._closed = True + + def push_to_hub( + self, + tags: list[str] | None = None, + private: bool = False, + ) -> dict[str, Any]: + """Push dataset to HuggingFace Hub. + + Args: + tags: Optional tags for the dataset + private: Upload as private dataset + + Returns: + Dict with push status + """ + try: + self.dataset.push_to_hub(tags=tags, private=private) + logger.info("Dataset pushed to hub: %s", self.dataset.repo_id) + return { + "status": "success", + "repo_id": self.dataset.repo_id, + "episodes": self.episode_count, + "frames": self.frame_count, + } + except Exception as e: + logger.error("push_to_hub failed: %s", e) + return {"status": "error", "message": str(e)} + + @property + def repo_id(self) -> str: + return self.dataset.repo_id + + @property + def root(self) -> str: + return str(self.dataset.root) + + def __repr__(self) -> str: + return f"DatasetRecorder(repo_id={self.repo_id}, episodes={self.episode_count}, frames={self.frame_count})" + + +# Shared replay-episode helpers + + +def load_lerobot_episode(repo_id: str, episode: int = 0, root: str | None = None): + """Load a LeRobotDataset and resolve the frame range for an episode. + + Returns: + Tuple of (dataset, episode_start, episode_length) on success. + + Raises: + ImportError: If lerobot is not installed. + ValueError: If the episode is out of range or has no frames. + """ + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + ds = LeRobotDataset(repo_id=repo_id, root=root) + + num_episodes = ds.meta.total_episodes if hasattr(ds.meta, "total_episodes") else len(ds.meta.episodes) + if episode >= num_episodes: + raise ValueError(f"Episode {episode} out of range (0-{num_episodes - 1})") + + episode_start = 0 + episode_length = 0 + try: + if hasattr(ds, "episode_data_index"): + from_idx = ds.episode_data_index["from"][episode].item() + to_idx = ds.episode_data_index["to"][episode].item() + episode_start = from_idx + episode_length = to_idx - from_idx + else: + for i in range(episode): + ep_info = ds.meta.episodes[i] if hasattr(ds.meta, "episodes") else {} + episode_start += ep_info.get("length", 0) + ep_info = ds.meta.episodes[episode] if hasattr(ds.meta, "episodes") else {} + episode_length = ep_info.get("length", 0) + except Exception: + # Last resort: scan frames to find episode boundaries + for idx in range(len(ds)): + frame = ds[idx] + frame_ep = frame.get("episode_index", -1) if hasattr(frame, "get") else -1 + if hasattr(frame_ep, "item"): + frame_ep = frame_ep.item() + if frame_ep == episode: + if episode_length == 0: + episode_start = idx + episode_length += 1 + elif episode_length > 0: + break + + if episode_length == 0: + raise ValueError(f"Episode {episode} has no frames") + + return ds, episode_start, episode_length diff --git a/strands_robots/policies/__init__.py b/strands_robots/policies/__init__.py index 6cffe18..048e268 100644 --- a/strands_robots/policies/__init__.py +++ b/strands_robots/policies/__init__.py @@ -1,6 +1,6 @@ """Policy Abstraction for Universal VLA Support. -Plugin-based registry — all provider definitions live in registry/policies.json. +Plugin-based registry - all provider definitions live in registry/policies.json. No hardcoded if/elif chains. New providers are auto-discovered or registered at runtime. Built-in providers (see policies.json for full list): diff --git a/strands_robots/policies/base.py b/strands_robots/policies/base.py index a082f4c..c2e5ed6 100644 --- a/strands_robots/policies/base.py +++ b/strands_robots/policies/base.py @@ -54,6 +54,18 @@ def set_robot_state_keys(self, robot_state_keys: list[str]) -> None: """Configure the policy with robot state keys.""" pass + @property + def requires_images(self) -> bool: + """Whether this policy needs camera frames in its observation. + + Default True (most VLA policies do). Subclasses that only consume + joint state (e.g. ``MockPolicy``, pure-IK controllers, scripted + trajectories) can return ``False`` to let the simulation skip + expensive camera rendering - a ~10x throughput win at 500Hz when + no cameras are needed. + """ + return True + @property @abstractmethod def provider_name(self) -> str: diff --git a/strands_robots/policies/factory.py b/strands_robots/policies/factory.py index 062978d..e25c842 100644 --- a/strands_robots/policies/factory.py +++ b/strands_robots/policies/factory.py @@ -1,4 +1,4 @@ -"""Policy factory — create_policy() and runtime registration.""" +"""Policy factory - create_policy() and runtime registration.""" import logging import os @@ -9,9 +9,9 @@ logger = logging.getLogger(__name__) -# ───────────────────────────────────────────────────────────────────── +# # Runtime registration (for user-defined providers not in JSON) -# ───────────────────────────────────────────────────────────────────── +# _runtime_registry: dict[str, Callable[[], type[Policy]]] = {} _runtime_aliases: dict[str, str] = {} diff --git a/strands_robots/policies/groot/__init__.py b/strands_robots/policies/groot/__init__.py index db09efe..8884c0c 100644 --- a/strands_robots/policies/groot/__init__.py +++ b/strands_robots/policies/groot/__init__.py @@ -1,4 +1,4 @@ -"""GR00T Policy — NVIDIA GR00T N1.5 and N1.6 support. +"""GR00T Policy - NVIDIA GR00T N1.5 and N1.6 support. Two inference modes: diff --git a/strands_robots/policies/groot/client.py b/strands_robots/policies/groot/client.py index e8cf0e6..5bf4257 100644 --- a/strands_robots/policies/groot/client.py +++ b/strands_robots/policies/groot/client.py @@ -1,4 +1,4 @@ -"""GR00T inference client — ZMQ client for inference-service communication. +"""GR00T inference client - ZMQ client for inference-service communication. Handles serialization of numpy arrays and ModalityConfig objects over ZMQ using msgpack with custom encode/decode hooks. @@ -138,7 +138,7 @@ def ping(self) -> bool: """Check server connectivity. Returns True if the server responds, False otherwise. - Does NOT auto-reconnect — call :meth:`reconnect` explicitly if needed. + Does NOT auto-reconnect - call :meth:`reconnect` explicitly if needed. """ try: self.call_endpoint("ping") @@ -185,7 +185,7 @@ def get_action(self, observations: dict[str, Any]) -> dict[str, Any]: is currently empty in all upstream embodiments. """ response = self.call_endpoint("get_action", {"observation": observations, "options": None}) - # N1.6/N1.7 servers return a (action_dict, info_dict) tuple—msgpack + # N1.6/N1.7 servers return a (action_dict, info_dict) tuple - msgpack # decodes tuples as lists, so we may see either shape here. if isinstance(response, list | tuple) and len(response) == 2: action, _info = response diff --git a/strands_robots/policies/groot/data_config.py b/strands_robots/policies/groot/data_config.py index e5fc879..6dc30d3 100644 --- a/strands_robots/policies/groot/data_config.py +++ b/strands_robots/policies/groot/data_config.py @@ -1,4 +1,4 @@ -"""GR00T data configuration — typed embodiment key mappings. +"""GR00T data configuration - typed embodiment key mappings. Provides :class:`Gr00tDataConfig` dataclasses and an ``_extends`` inheritance mechanism so new robot configs can be defined by overriding only what differs @@ -59,9 +59,7 @@ def modality_config(self) -> dict[str, ModalityConfig]: } -# --------------------------------------------------------------------------- # Config resolution with _extends inheritance -# --------------------------------------------------------------------------- def _resolve_config(name: str, definitions: dict) -> Gr00tDataConfig: @@ -88,9 +86,7 @@ def _resolve_config(name: str, definitions: dict) -> Gr00tDataConfig: return Gr00tDataConfig(**merged) -# --------------------------------------------------------------------------- # Load configs from JSON -# --------------------------------------------------------------------------- _CONFIG_FILE = Path(__file__).parent / "data_configs.json" diff --git a/strands_robots/policies/groot/policy.py b/strands_robots/policies/groot/policy.py index 162a80f..8d23bae 100644 --- a/strands_robots/policies/groot/policy.py +++ b/strands_robots/policies/groot/policy.py @@ -1,4 +1,4 @@ -"""GR00T policy — N1.5/N1.6 service and local inference. +"""GR00T policy - N1.5/N1.6 service and local inference. Implements :class:`~strands_robots.policies.base.Policy` for NVIDIA GR00T models. @@ -33,9 +33,7 @@ logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- # Isaac-GR00T version detection -# --------------------------------------------------------------------------- _GROOT_VERSION: str | None = None # "n1.5", "n1.6", "n1.7", or None @@ -61,7 +59,7 @@ def _detect_groot_version(*, force: bool = False) -> str | None: # Reset before re-detection _GROOT_VERSION = None - # N1.7 first — the new Cosmos-Reason2-2B backbone lives here. + # N1.7 first - the new Cosmos-Reason2-2B backbone lives here. # Detecting by subpackage (not enum values) keeps the probe cheap. try: if importlib.util.find_spec("gr00t.model.gr00t_n1d7") is not None: @@ -90,9 +88,7 @@ def _detect_groot_version(*, force: bool = False) -> str | None: return None -# --------------------------------------------------------------------------- # Mapping dataclasses -# --------------------------------------------------------------------------- @dataclass(frozen=True) @@ -100,8 +96,8 @@ class ObservationMapping: """Maps robot sensor names → model modality keys. Attributes: - video: ``{robot_camera: model_video_key}`` — bare, no prefix. - state: ``{robot_state: model_state_key}`` — bare, no prefix. + video: ``{robot_camera: model_video_key}`` - bare, no prefix. + state: ``{robot_state: model_state_key}`` - bare, no prefix. language_key: Model's language key (e.g. ``"task"``). """ @@ -139,7 +135,7 @@ class ActionMapping: """Maps model action keys → robot actuator names. Attributes: - actions: ``{model_action_key: robot_actuator}`` — bare, no prefix. + actions: ``{model_action_key: robot_actuator}`` - bare, no prefix. """ actions: dict[str, str] = field(default_factory=dict) @@ -152,9 +148,7 @@ def validate(self, modality_configs: dict) -> None: raise ValueError(f"Action mapping: model key '{model_key}' not in model: {sorted(model_action)}") -# --------------------------------------------------------------------------- # Auto-inference (exact name match → positional fallback) -# --------------------------------------------------------------------------- def _auto_infer_observation_mapping( @@ -214,9 +208,7 @@ def _match_keys(ours: list[str], model: list[str], label: str) -> dict[str, str] return mapping -# --------------------------------------------------------------------------- # Parse user-provided flat mapping dicts -# --------------------------------------------------------------------------- def _parse_observation_mapping( @@ -247,13 +239,11 @@ def _parse_action_mapping(flat: dict[str, str]) -> ActionMapping: return ActionMapping(actions={k.removeprefix("action."): v for k, v in flat.items()}) -# --------------------------------------------------------------------------- # Gr00tPolicy -# --------------------------------------------------------------------------- class Gr00tPolicy(Policy): - """GR00T policy — service mode and local inference (N1.5/N1.6). + """GR00T policy - service mode and local inference (N1.5/N1.6). For **local mode**, loads the model directly and talks its native nested-dict format. Robot↔model key translation is done by explicit mappings. @@ -317,7 +307,7 @@ def __init__( self._groot_version = groot_version or _detect_groot_version() self._strict = strict - # DOF per model state key — discovered from model at load time + # DOF per model state key - discovered from model at load time self._model_state_dof: dict[str, int] = {} # Raw user mappings (parsed after model load) @@ -348,9 +338,7 @@ def __init__( self.data_config_name, ) - # ------------------------------------------------------------------ # Mapping initialization - # ------------------------------------------------------------------ def _init_mappings(self) -> None: """Initialize observation/action mappings after model load.""" @@ -463,16 +451,14 @@ def _discover_model_state_dof(self, mmc: dict) -> None: missing = all_keys - discovered if missing: logger.warning( - "Could not discover DOF for state keys: %s — these will not be zero-filled if unmapped", + "Could not discover DOF for state keys: %s - these will not be zero-filled if unmapped", sorted(missing), ) if self._model_state_dof: logger.info("Model state DOF: %s", self._model_state_dof) - # ------------------------------------------------------------------ # Model loading - # ------------------------------------------------------------------ def _load_local_policy(self, model_path: str, embodiment_tag: str, device: str): if self._groot_version == "n1.7": @@ -504,7 +490,7 @@ def _load_n15(self, model_path: str, embodiment_tag: str, device: str): logger.info("GR00T N1.5 loaded from %s", model_path) def _load_n16(self, model_path: str, embodiment_tag: str, device: str): - """Load N1.6 — uses Gr00tPolicy directly (NOT SimPolicyWrapper).""" + """Load N1.6 - uses Gr00tPolicy directly (NOT SimPolicyWrapper).""" from gr00t.data.embodiment_tags import EmbodimentTag from gr00t.policy.gr00t_policy import Gr00tPolicy as N16Policy @@ -518,7 +504,7 @@ def _load_n16(self, model_path: str, embodiment_tag: str, device: str): logger.info("GR00T N1.6 loaded from %s (direct)", model_path) def _load_n17(self, model_path: str, embodiment_tag: str, device: str): - """Load N1.7 — identical entry point to N1.6 (same ``Gr00tPolicy`` signature). + """Load N1.7 - identical entry point to N1.6 (same ``Gr00tPolicy`` signature). The user-visible policy class is still ``gr00t.policy.gr00t_policy.Gr00tPolicy``; internally it pulls the new Cosmos-Reason2-2B / Qwen3-VL backbone via @@ -537,9 +523,7 @@ def _load_n17(self, model_path: str, embodiment_tag: str, device: str): ) logger.info("GR00T N1.7 loaded from %s (direct)", model_path) - # ------------------------------------------------------------------ # Policy interface - # ------------------------------------------------------------------ @property def provider_name(self) -> str: @@ -553,9 +537,7 @@ async def get_actions(self, observation_dict: dict[str, Any], instruction: str, return self._local_get_actions(observation_dict, instruction) return self._service_get_actions(observation_dict, instruction) - # ------------------------------------------------------------------ - # Local inference — talks model's native nested-dict format - # ------------------------------------------------------------------ + # Local inference - talks model's native nested-dict format def _local_get_actions(self, robot_obs: dict[str, Any], instruction: str) -> list[dict[str, Any]]: """Local: prepare nested obs → infer → unpack actions.""" @@ -589,7 +571,7 @@ def _prepare_observation(self, robot_obs: dict[str, Any], instruction: str) -> d assert self._obs_mapping is not None, "Observation mapping not initialized" - # ── Video ── + # Video mapped_video_keys = set(self._obs_mapping.video.keys()) for robot_key, model_key in self._obs_mapping.video.items(): if robot_key in robot_obs: @@ -603,7 +585,7 @@ def _prepare_observation(self, robot_obs: dict[str, Any], instruction: str) -> d ref = _reference_video_shape(robot_obs, mapped_video_keys) video_dict[model_key] = np.zeros((1, 1, *ref), dtype=np.uint8) - # ── State ── + # State for robot_key, model_key in self._obs_mapping.state.items(): if robot_key in robot_obs: state_dict[model_key] = _to_state_batch(robot_obs[robot_key]) @@ -619,11 +601,11 @@ def _prepare_observation(self, robot_obs: dict[str, Any], instruction: str) -> d state_dict[model_key] = np.zeros((1, 1, dof), dtype=np.float32) else: logger.debug( - "Skipping zero-fill for '%s' — DOF unknown", + "Skipping zero-fill for '%s' - DOF unknown", model_key, ) - # ── Language ── + # Language lang_key = self._obs_mapping.language_key language_dict = {lang_key: [[instruction]]} @@ -663,9 +645,7 @@ def _unpack_actions(self, raw_actions: dict) -> list[dict[str, Any]]: return actions - # ------------------------------------------------------------------ # Service inference - # ------------------------------------------------------------------ def _service_get_actions(self, robot_obs: dict[str, Any], instruction: str) -> list[dict[str, Any]]: """Service mode: build observation, call server, unpack.""" @@ -735,7 +715,7 @@ def _unpack_service_actions(self, action_chunk: dict) -> list[dict[str, Any]]: actions.append(step) return actions - # No mapping — return bare model keys + # No mapping - return bare model keys actions = [] for t in range(horizon): step = {} @@ -746,9 +726,7 @@ def _unpack_service_actions(self, action_chunk: dict) -> list[dict[str, Any]]: return actions -# --------------------------------------------------------------------------- -# Shape helpers — match Isaac-GR00T's expected formats exactly -# --------------------------------------------------------------------------- +# Shape helpers - match Isaac-GR00T's expected formats exactly def _to_video_batch(value: np.ndarray) -> np.ndarray: diff --git a/strands_robots/policies/lerobot_local/__init__.py b/strands_robots/policies/lerobot_local/__init__.py index 7f1c281..c75ce8b 100644 --- a/strands_robots/policies/lerobot_local/__init__.py +++ b/strands_robots/policies/lerobot_local/__init__.py @@ -1,4 +1,4 @@ -"""LeRobot Local Policy — Direct HuggingFace model inference (no server needed).""" +"""LeRobot Local Policy - Direct HuggingFace model inference (no server needed).""" from .policy import LerobotLocalPolicy diff --git a/strands_robots/policies/lerobot_local/policy.py b/strands_robots/policies/lerobot_local/policy.py index b8e5817..a9e9171 100644 --- a/strands_robots/policies/lerobot_local/policy.py +++ b/strands_robots/policies/lerobot_local/policy.py @@ -1,4 +1,4 @@ -"""LeRobot Local Policy — Direct HuggingFace model inference (no server needed). +"""LeRobot Local Policy - Direct HuggingFace model inference (no server needed). Uses LeRobot's own factory for auto-detection. Any model LeRobot supports, this policy supports. @@ -186,9 +186,7 @@ def set_robot_state_keys(self, robot_state_keys: list[str]) -> None: "Call set_robot_state_keys() with the robot's actual joint/motor names." ) - # ------------------------------------------------------------------ # Tokenizer resolution (VLA language token injection) - # ------------------------------------------------------------------ def _resolve_tokenizer(self) -> Any | None: """Resolve and cache the tokenizer for VLA language token injection. @@ -288,9 +286,7 @@ def _needs_language_tokens(self) -> bool: return False - # ------------------------------------------------------------------ # Model loading - # ------------------------------------------------------------------ def _load_model(self) -> None: """Load the LeRobot model from pretrained path. @@ -381,7 +377,7 @@ def _load_model(self) -> None: self._processor_bridge = None logger.debug("No processor configs found, using raw obs/action flow") except (FileNotFoundError, ValueError, ImportError) as exc: - # Processor bridge is optional — models work without it via raw obs/action flow. + # Processor bridge is optional - models work without it via raw obs/action flow. # Fail-fast only if the user explicitly requested processor overrides. if self.processor_overrides: raise RuntimeError( @@ -393,9 +389,7 @@ def _load_model(self) -> None: # Initialize RTC if supported by this policy self._init_rtc() - # ------------------------------------------------------------------ # Real-Time Chunking (RTC) support - # ------------------------------------------------------------------ def _init_rtc(self) -> None: """Initialize RTC if the loaded policy supports it. @@ -422,7 +416,7 @@ def _init_rtc(self) -> None: return # Auto-detect from model config. - # RTC requires rtc_config on the model — not just predict_action_chunk(). + # RTC requires rtc_config on the model - not just predict_action_chunk(). # In LeRobot 0.5+, predict_action_chunk() is a base class method that ALL # policies inherit (ACT, Diffusion, etc.), but only flow-matching policies # (Pi0, SmolVLA) have an rtc_config that parameterizes the denoiser for @@ -437,7 +431,7 @@ def _init_rtc(self) -> None: elif self._rtc_requested is True: if rtc_config is None: # User explicitly asked for RTC, but this policy has no rtc_config. - # This means it's not a flow-matching policy — warn and disable. + # This means it's not a flow-matching policy - warn and disable. logger.warning( "RTC requested but policy '%s' has no rtc_config. " "RTC is only supported by flow-matching policies (Pi0, SmolVLA). " @@ -509,7 +503,7 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: batch: Observation batch tensors ready for the policy. Returns: - Action tensor — first action(s) from the chunk, accounting for + Action tensor - first action(s) from the chunk, accounting for inference delay. """ inference_start = time.time() @@ -532,7 +526,7 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: if action_chunk.dim() == 3 and action_chunk.shape[0] == 1: action_chunk = action_chunk.squeeze(0) - # Estimate inference delay — how many steps were consumed while computing + # Estimate inference delay - how many steps were consumed while computing inference_delay = self._estimate_inference_delay() # Store leftover for next RTC call (unconsumed portion of this chunk) @@ -547,11 +541,11 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: else: self._rtc_prev_chunk = None - # Skip delay steps — they correspond to time spent during inference + # Skip delay steps - they correspond to time spent during inference usable_start = min(inference_delay, action_chunk.shape[0] - 1) usable_actions = action_chunk[usable_start:] - # Log RTC details at debug level — throttled to once every 2s regardless of Hz + # Log RTC details at debug level - throttled to once every 2s regardless of Hz _now = time.monotonic() if _now - self._rtc_last_log_time >= 2.0: self._rtc_last_log_time = _now @@ -566,9 +560,7 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: return usable_actions - # ------------------------------------------------------------------ # Inference - # ------------------------------------------------------------------ async def get_actions(self, observation_dict: dict[str, Any], instruction: str, **kwargs) -> list[dict[str, Any]]: """Get actions from policy given observation and instruction. @@ -637,9 +629,7 @@ async def get_actions(self, observation_dict: dict[str, Any], instruction: str, return self._tensor_to_action_dicts(action_tensor) - # ------------------------------------------------------------------ # Observation batch building - # ------------------------------------------------------------------ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: """Fix up a preprocessor-produced batch so every value is a proper batched tensor. @@ -666,7 +656,7 @@ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: fixed: dict[str, Any] = {} for key, val in batch.items(): - # --- numpy arrays → torch tensors --- + # numpy arrays → torch tensors if isinstance(val, np.ndarray): if "image" in key: # HWC uint8 → CHW float32 → (1,C,H,W) @@ -682,7 +672,7 @@ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: t = t.unsqueeze(0) # (D,) → (1,D) fixed[key] = t.to(device) - # --- torch tensors: ensure batch dim + device --- + # torch tensors: ensure batch dim + device elif isinstance(val, torch.Tensor): # Auto-cast float64 → float32: ROS/dynamixel drivers often produce float64 t = val.float() if val.dtype == torch.float64 else val @@ -695,7 +685,7 @@ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: t = t.unsqueeze(0) # (D,) → (1,D) fixed[key] = t.to(device) - # --- pass through anything else (strings, etc.) --- + # pass through anything else (strings, etc.) else: fixed[key] = val @@ -772,7 +762,7 @@ def _build_batch_from_lerobot_format( - Scalars → float32 tensor with batch dim Non-numeric types (strings, pre-batched int64 tokens) are passed through - unchanged — LeRobot expects these as-is for task descriptions and + unchanged - LeRobot expects these as-is for task descriptions and pre-tokenized inputs. Args: @@ -813,7 +803,7 @@ def _build_batch_from_lerobot_format( is_image = True if is_image and tensor.dim() == 3 and tensor.shape[-1] in (1, 3, 4): tensor = tensor.permute(2, 0, 1) - # uint8 images are [0, 255] — normalize to [0, 1] for model input + # uint8 images are [0, 255] - normalize to [0, 1] for model input if is_image and value.dtype == np.uint8: tensor = tensor / 255.0 if is_image and tensor.dim() == 3: @@ -829,7 +819,7 @@ def _build_batch_from_lerobot_format( try: array = np.array(value, dtype=np.float32) except (ValueError, TypeError): - # Non-numeric lists (e.g. string lists) — skip silently, they aren't tensor data + # Non-numeric lists (e.g. string lists) - skip silently, they aren't tensor data logger.debug("Skipping non-numeric list/tuple for key in observation batch") continue tensor = torch.from_numpy(array).float() @@ -868,7 +858,7 @@ def _build_batch_from_strands_format( """ if not self.robot_state_keys: raise ValueError( - "robot_state_keys is empty — cannot map observation to state tensor. " + "robot_state_keys is empty - cannot map observation to state tensor. " "Call set_robot_state_keys() with the robot's motor names." ) @@ -895,7 +885,7 @@ def _build_batch_from_strands_format( expected_dim = state_feature.shape[0] if hasattr(state_feature, "shape") else len(state_values) if len(state_values) > expected_dim: logger.warning( - "State dim %d > model expects %d — truncating to first %d values. " + "State dim %d > model expects %d - truncating to first %d values. " "Check that robot_state_keys matches your robot's actual joint count.", len(state_values), expected_dim, @@ -904,7 +894,7 @@ def _build_batch_from_strands_format( state_values = state_values[:expected_dim] elif len(state_values) < expected_dim: logger.warning( - "State dim %d < model expects %d — zero-padding with %d zeros. " + "State dim %d < model expects %d - zero-padding with %d zeros. " "Check that robot_state_keys matches your robot's actual joint count.", len(state_values), expected_dim, @@ -936,9 +926,7 @@ def _build_batch_from_strands_format( return batch - # ------------------------------------------------------------------ # Action conversion - # ------------------------------------------------------------------ def _tensor_to_action_dicts(self, action_tensor: torch.Tensor) -> list[dict[str, Any]]: """Convert action tensor to list of robot action dicts. diff --git a/strands_robots/policies/lerobot_local/processor.py b/strands_robots/policies/lerobot_local/processor.py index 30362b6..7b27237 100644 --- a/strands_robots/policies/lerobot_local/processor.py +++ b/strands_robots/policies/lerobot_local/processor.py @@ -127,7 +127,7 @@ def from_pretrained( ) logger.info("Loaded preprocessor from %s: %d steps", pretrained_name_or_path, len(preprocessor)) except (FileNotFoundError, ValueError) as exc: - # No config file found — model doesn't ship a preprocessor. This is normal. + # No config file found - model doesn't ship a preprocessor. This is normal. logger.debug("No preprocessor found: %s", exc) # Load postprocessor @@ -139,7 +139,7 @@ def from_pretrained( ) logger.info("Loaded postprocessor from %s: %d steps", pretrained_name_or_path, len(postprocessor)) except (FileNotFoundError, ValueError) as exc: - # No config file found — model doesn't ship a postprocessor. This is normal. + # No config file found - model doesn't ship a postprocessor. This is normal. logger.debug("No postprocessor found: %s", exc) return cls( diff --git a/strands_robots/policies/lerobot_local/resolution.py b/strands_robots/policies/lerobot_local/resolution.py index 401b666..9783e90 100644 --- a/strands_robots/policies/lerobot_local/resolution.py +++ b/strands_robots/policies/lerobot_local/resolution.py @@ -35,7 +35,7 @@ def _ensure_policy_configs_registered() -> None: each config module has module-level side effects that populate the registry. This function imports a single known config to bootstrap the entire registry. - It's safe to call multiple times — the import is idempotent. + It's safe to call multiple times - the import is idempotent. """ global _CONFIGS_REGISTERED if _CONFIGS_REGISTERED: @@ -49,7 +49,7 @@ def _ensure_policy_configs_registered() -> None: _CONFIGS_REGISTERED = True logger.debug("LeRobot policy configs registered in draccus choice registry") except (ImportError, ModuleNotFoundError): - # Pre-0.5 lerobot or missing policy subpackage — that's OK, + # Pre-0.5 lerobot or missing policy subpackage - that's OK, # the caller will fall through to manual resolution. logger.debug("Could not import lerobot policy configs for draccus registration") except Exception as exc: @@ -97,7 +97,7 @@ class lookup, and weight loading via the draccus config registry. logger.debug("PreTrainedConfig resolution failed, trying manual: %s", exc) except Exception as exc: # draccus raises DecodingError/ParsingError which are NOT subclasses - # of RuntimeError/ValueError — they inherit from DraccusException → Exception. + # of RuntimeError/ValueError - they inherit from DraccusException → Exception. # Catch broadly here but only for draccus-related errors. if "draccus" in type(exc).__module__ or "DecodingError" in type(exc).__name__: logger.debug("PreTrainedConfig draccus error, trying manual: %s", exc) @@ -124,7 +124,7 @@ def _ensure_lerobot_policies_importable() -> None: its ``__init__.py``. LeRobot 0.5+ has a ``lerobot/policies/__init__.py`` that eagerly imports - **all** policy packages (groot, act, diffusion, …). The groot import chain + **all** policy packages (groot, act, diffusion, ...). The groot import chain pulls in ``transformers`` → ``flash_attn`` which can crash at module load time on environments with ABI mismatches (e.g. wrong torch / flash-attn version combo). @@ -145,7 +145,7 @@ def _ensure_lerobot_policies_importable() -> None: key = "lerobot.policies" if key in sys.modules: - # Already imported (successfully or via a previous stub) — nothing to do. + # Already imported (successfully or via a previous stub) - nothing to do. return try: @@ -234,7 +234,7 @@ def resolve_policy_class_by_name(policy_type: str) -> type[Any]: except (ImportError, AttributeError, RuntimeError): pass - # Strategy 4: PreTrainedPolicy — only if it's NOT abstract + # Strategy 4: PreTrainedPolicy - only if it's NOT abstract try: from lerobot.policies.pretrained import PreTrainedPolicy diff --git a/strands_robots/policies/mock.py b/strands_robots/policies/mock.py index e4fa38c..6c97b02 100644 --- a/strands_robots/policies/mock.py +++ b/strands_robots/policies/mock.py @@ -1,4 +1,4 @@ -"""Mock policy for testing — generates smooth sinusoidal trajectories.""" +"""Mock policy for testing - generates smooth sinusoidal trajectories.""" import logging import math @@ -10,7 +10,7 @@ class MockPolicy(Policy): - """Mock policy for testing — generates smooth sinusoidal trajectories.""" + """Mock policy for testing - generates smooth sinusoidal trajectories.""" def __init__(self, **kwargs: Any) -> None: self.robot_state_keys: list[str] = [] @@ -21,6 +21,11 @@ def __init__(self, **kwargs: Any) -> None: def provider_name(self) -> str: return "mock" + @property + def requires_images(self) -> bool: + """Mock policy only consumes joint state - skip camera rendering.""" + return False + def set_robot_state_keys(self, robot_state_keys: list[str]) -> None: self.robot_state_keys = robot_state_keys diff --git a/strands_robots/registry/__init__.py b/strands_robots/registry/__init__.py index 2d6ba3d..ef36cb5 100644 --- a/strands_robots/registry/__init__.py +++ b/strands_robots/registry/__init__.py @@ -1,4 +1,4 @@ -"""Unified Registry — single source of truth for robots and policies. +"""Unified Registry - single source of truth for robots and policies. Loads robot definitions and policy provider configs from JSON files. @@ -6,7 +6,7 @@ - **One file to edit**: Add a robot → edit robots.json, done. - **Hot-reload**: JSON is re-read when the file changes (mtime check). - **Self-contained entries**: Each robot/policy owns its aliases, - shorthands, and URL patterns — no separate lookup tables. + shorthands, and URL patterns - no separate lookup tables. - **Validation**: Duplicate aliases, shorthands, and URL patterns are caught on load with clear error messages. diff --git a/strands_robots/registry/policies.py b/strands_robots/registry/policies.py index 07ea6f2..c525b90 100644 --- a/strands_robots/registry/policies.py +++ b/strands_robots/registry/policies.py @@ -1,4 +1,4 @@ -"""Policy registry — resolve, import, and configure policy providers. +"""Policy registry - resolve, import, and configure policy providers. All provider definitions live in policies.json. This module provides the public read API for resolving smart policy strings, importing provider @@ -63,7 +63,7 @@ def resolve_policy(policy: str, **extra_kwargs) -> tuple[str, dict[str, Any]]: 5. Fallback to lerobot_local Args: - policy: Smart string — HF model ID, URL, or provider name. + policy: Smart string - HF model ID, URL, or provider name. **extra_kwargs: Additional kwargs merged into result. Returns: @@ -84,7 +84,7 @@ def resolve_policy(policy: str, **extra_kwargs) -> tuple[str, dict[str, Any]]: policy = policy.strip() kwargs: dict[str, Any] = {} - # 1. URL pattern matching — check each provider's url_patterns + # 1. URL pattern matching - check each provider's url_patterns for prov_name, prov_info in providers.items(): for pattern in prov_info.get("url_patterns", []): if re.match(pattern, policy): @@ -105,7 +105,7 @@ def resolve_policy(policy: str, **extra_kwargs) -> tuple[str, dict[str, Any]]: kwargs.update(extra_kwargs) return prov_name, kwargs - # 2. Shorthand names — built from each provider's shorthands list + # 2. Shorthand names - built from each provider's shorthands list alias_map = _build_alias_map() if policy.lower() in alias_map: kwargs.update(extra_kwargs) diff --git a/strands_robots/registry/robots.py b/strands_robots/registry/robots.py index 779f5fa..81d873b 100644 --- a/strands_robots/registry/robots.py +++ b/strands_robots/registry/robots.py @@ -1,4 +1,4 @@ -"""Robot registry — query, resolve, and list robot definitions. +"""Robot registry - query, resolve, and list robot definitions. All robot definitions live in robots.json. This module provides the public read API; the JSON file is the only thing you edit to add @@ -55,7 +55,7 @@ def get_robot(name: str) -> dict[str, Any] | None: Returns: Robot dict with keys like description, category, joints, asset, - hardware — or None if not found. + hardware - or None if not found. """ reg = _load("robots") canonical = resolve_name(name) @@ -92,7 +92,7 @@ def list_robots(mode: str = "all") -> list[dict[str, Any]]: """List available robots, optionally filtered. Args: - mode: Filter — "all", "sim", "real", or "both" (has sim AND real). + mode: Filter - "all", "sim", "real", or "both" (has sim AND real). Returns: List of dicts with name, description, has_sim, has_real. @@ -137,19 +137,54 @@ def list_aliases() -> dict[str, str]: return _build_alias_map() -def format_robot_table() -> str: - """Human-readable table of all robots for CLI/tool output.""" - lines = [ - f"{'Name':<20} {'Category':<15} {'Joints':<8} {'Sim':<5} {'Real':<5} Description", - "─" * 100, - ] - for cat in ["arm", "bimanual", "hand", "humanoid", "expressive", "mobile", "mobile_manip"]: +_NAME_WIDTH = 20 +_CAT_WIDTH = 15 +_JOINTS_WIDTH = 8 +_SIM_WIDTH = 5 +_REAL_WIDTH = 5 +# Width of the fixed prefix columns, including single-space separators. +_FIXED_PREFIX_WIDTH = _NAME_WIDTH + 1 + _CAT_WIDTH + 1 + _JOINTS_WIDTH + 1 + _SIM_WIDTH + 1 + _REAL_WIDTH + 1 + + +def format_robot_table(max_width: int = 100) -> str: + """Human-readable table of all robots for CLI/tool output. + + Args: + max_width: Target terminal width. The ``Description`` column is + truncated with an ellipsis to fit. Pass a large value (e.g. + ``1000``) to disable truncation entirely. Default 100 is safe + for a typical 100-column terminal. + """ + desc_width = max(20, max_width - _FIXED_PREFIX_WIDTH) + + header = ( + f"{'Name':<{_NAME_WIDTH}} " + f"{'Category':<{_CAT_WIDTH}} " + f"{'Joints':<{_JOINTS_WIDTH}} " + f"{'Sim':<{_SIM_WIDTH}} " + f"{'Real':<{_REAL_WIDTH}} " + f"Description" + ) + rule_width = min(max(max_width, len(header)), _FIXED_PREFIX_WIDTH + desc_width) + lines = [header, "─" * rule_width] + + for cat in ["arm", "bimanual", "hand", "humanoid", "expressive", "mobile", "mobile_manip", "aerial"]: by_cat = list_robots_by_category() for r in by_cat.get(cat, []): sim = "✅" if r["has_sim"] else " " real = "✅" if r["has_real"] else " " joints = str(r["joints"]) if r["joints"] else "?" - lines.append(f"{r['name']:<20} {r['category']:<15} {joints:<8} {sim:<5} {real:<5} {r['description']}") + desc = r["description"] or "" + if len(desc) > desc_width: + desc = desc[: desc_width - 3].rstrip() + "..." + lines.append( + f"{r['name']:<{_NAME_WIDTH}} " + f"{r['category']:<{_CAT_WIDTH}} " + f"{joints:<{_JOINTS_WIDTH}} " + f"{sim:<{_SIM_WIDTH}} " + f"{real:<{_REAL_WIDTH}} " + f"{desc}" + ) robots = list_robots() lines.append("") diff --git a/strands_robots/registry/user_registry.py b/strands_robots/registry/user_registry.py index eb55843..8f67f14 100644 --- a/strands_robots/registry/user_registry.py +++ b/strands_robots/registry/user_registry.py @@ -1,4 +1,4 @@ -"""User-local robot registry — runtime registration without editing package JSON. +"""User-local robot registry - runtime registration without editing package JSON. Provides ``register_robot()`` and ``unregister_robot()`` for adding custom robots that persist across sessions via a ``user_robots.json`` file stored @@ -13,7 +13,7 @@ user registry. Use ``STRANDS_BASE_DIR`` to relocate user metadata. At load time the user overlay is merged *on top of* the package -``robots.json`` — user entries win on name collision, so you can also +``robots.json`` - user entries win on name collision, so you can also override built-in robots locally. Usage:: @@ -34,7 +34,7 @@ from strands_robots.simulation import create_simulation sim = create_simulation() sim.create_world() - sim.add_robot("my_arm") # ✅ auto-resolved + sim.add_robot("my_arm") # auto-resolved # Remove it unregister_robot("my_arm") @@ -176,7 +176,7 @@ def register_robot( if _pkg_get_robot(name) is not None: logger.info( - "Robot '%s' exists in package registry — user registration will override it.", + "Robot '%s' exists in package registry - user registration will override it.", name, ) except ImportError: @@ -189,7 +189,7 @@ def register_robot( # This matches how resolve_model_path works: search_dir / asset["dir"] / xml dir_name = resolved_dir.name - # Alias collision detection — warn (don't fail) when a user alias shadows a + # Alias collision detection - warn (don't fail) when a user alias shadows a # canonical name or another alias. Doing this at registration surfaces the # problem immediately instead of at silent resolution-order time. if aliases and not overwrite: @@ -218,7 +218,7 @@ def register_robot( logger.warning("Alias %r is already used by another robot.", alias) # Validate model_xml exists. Previously we only checked when - # ``resolved_dir`` existed — which silently accepted registrations for + # ``resolved_dir`` existed - which silently accepted registrations for # dirs that didn't exist yet and surfaced a confusing error only at # ``add_robot()`` time. Now we fail-closed on both conditions so the # user gets an immediate, actionable error at registration time. @@ -283,7 +283,7 @@ def unregister_robot(name: str) -> bool: data = _load_user_registry() if name not in data.get("robots", {}): - logger.info("Robot '%s' not in user registry — nothing to remove.", name) + logger.info("Robot '%s' not in user registry - nothing to remove.", name) return False del data["robots"][name] diff --git a/strands_robots/robot.py b/strands_robots/robot.py index 33c5ba7..3927afe 100644 --- a/strands_robots/robot.py +++ b/strands_robots/robot.py @@ -236,7 +236,7 @@ async def _connect_robot(self) -> tuple[bool, str]: # Check if already connected if self.robot.is_connected: - logger.info(f"✅ {self.robot} already connected") + logger.info(f"{self.robot} already connected") return True, "" logger.info(f"🔌 Connecting to {self.robot}...") @@ -248,13 +248,13 @@ async def _connect_robot(self) -> tuple[bool, str]: except DeviceAlreadyConnectedError: # This is expected and fine - robot is already connected - logger.info(f"✅ {self.robot} was already connected") + logger.info(f"{self.robot} was already connected") except Exception as e: # Check if it's the string version of "already connected" error error_str = str(e).lower() if "already connected" in error_str or "is already connected" in error_str: - logger.info(f"✅ {self.robot} connection already established") + logger.info(f"{self.robot} connection already established") else: # Re-raise if it's a different error raise e @@ -262,7 +262,7 @@ async def _connect_robot(self) -> tuple[bool, str]: # Final connection check if not self.robot.is_connected: error_msg = f"Failed to connect to {self.robot}" - logger.error(f"❌ {error_msg}") + logger.error(f"{error_msg}") return False, error_msg # Check robot calibration @@ -271,15 +271,15 @@ async def _connect_robot(self) -> tuple[bool, str]: f"Robot {self.robot} is not calibrated. Please calibrate the robot manually" " first using LeRobot's calibration process (lerobot-calibrate)" ) - logger.error(f"❌ {error_msg}") + logger.error(f"{error_msg}") return False, error_msg - logger.info(f"✅ {self.robot} connected and ready") + logger.info(f"{self.robot} connected and ready") return True, "" except Exception as e: error_msg = f"Robot connection failed: {e}. Ensure robot is calibrated and accessible on the specified port" - logger.error(f"❌ {error_msg}") + logger.error(f"{error_msg}") return False, error_msg async def _initialize_policy(self, policy: Policy) -> bool: @@ -300,7 +300,7 @@ async def _initialize_policy(self, policy: Policy) -> bool: return True except Exception as e: - logger.error(f"❌ Failed to initialize policy: {e}") + logger.error(f"Failed to initialize policy: {e}") return False async def _execute_task_async( @@ -370,12 +370,10 @@ async def _execute_task_async( if self._task_state.status == TaskStatus.RUNNING: self._task_state.status = TaskStatus.COMPLETED - logger.info( - f"✅ Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)" - ) + logger.info(f"Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)") except Exception as e: - logger.error(f"❌ Task execution failed: {e}") + logger.error(f"Task execution failed: {e}") self._task_state.status = TaskStatus.ERROR self._task_state.error_message = str(e) @@ -415,12 +413,12 @@ async def task_runner(): "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error", "content": [ { - "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n" - f"🤖 Robot: {self.tool_name_str} ({self.robot})\n" - f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps: {self._task_state.step_count}" - + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "") + "text": f"Task: '{instruction}' - {self._task_state.status.value}\n" + f"Robot: {self.tool_name_str} ({self.robot})\n" + f"Policy: {policy_provider} on {policy_host}:{policy_port}\n" + f"Duration: {self._task_state.duration:.1f}s\n" + f"Steps: {self._task_state.step_count}" + + (f"\nError: {self._task_state.error_message}" if self._task_state.error_message else "") } ], } @@ -439,7 +437,7 @@ def start_task( if self._task_state.status == TaskStatus.RUNNING: return { "status": "error", - "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}], + "content": [{"text": f"Task already running: {self._task_state.instruction}"}], } # Start task in background @@ -451,10 +449,10 @@ def start_task( "status": "success", "content": [ { - "text": f"🚀 Task started: '{instruction}'\n" - f"🤖 Robot: {self.tool_name_str}\n" - f"💡 Use action='status' to check progress\n" - f"💡 Use action='stop' to interrupt" + "text": f"Task started: '{instruction}'\n" + f"Robot: {self.tool_name_str}\n" + f"Use action='status' to check progress\n" + f"Use action='stop' to interrupt" } ], } @@ -466,20 +464,20 @@ def get_task_status(self) -> dict[str, Any]: if self._task_state.status == TaskStatus.RUNNING: self._task_state.duration = time.time() - self._task_state.start_time - status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n" + status_text = f"Robot Status: {self._task_state.status.value.upper()}\n" if self._task_state.instruction: - status_text += f"🎯 Task: {self._task_state.instruction}\n" + status_text += f"Task: {self._task_state.instruction}\n" if self._task_state.status == TaskStatus.RUNNING: - status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - status_text += f"🔄 Steps: {self._task_state.step_count}\n" + status_text += f"Duration: {self._task_state.duration:.1f}s\n" + status_text += f"Steps: {self._task_state.step_count}\n" elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]: - status_text += f"⏱️ Total Duration: {self._task_state.duration:.1f}s\n" - status_text += f"🎯 Total Steps: {self._task_state.step_count}\n" + status_text += f"Total Duration: {self._task_state.duration:.1f}s\n" + status_text += f"Total Steps: {self._task_state.step_count}\n" if self._task_state.error_message: - status_text += f"❌ Error: {self._task_state.error_message}\n" + status_text += f"Error: {self._task_state.error_message}\n" return { "status": "success", @@ -502,15 +500,15 @@ def stop_task(self) -> dict[str, Any]: if self._task_state.task_future: self._task_state.task_future.cancel() - logger.info(f"🛑 Task stopped: {self._task_state.instruction}") + logger.info(f"Task stopped: {self._task_state.instruction}") return { "status": "success", "content": [ { - "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps completed: {self._task_state.step_count}" + "text": f"Task stopped: '{self._task_state.instruction}'\n" + f"Duration: {self._task_state.duration:.1f}s\n" + f"Steps completed: {self._task_state.step_count}" } ], } @@ -601,7 +599,7 @@ async def stream( tool_use_id, { "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for execute action"}], + "content": [{"text": "Instruction and policy_port are required for execute action"}], }, ) ) @@ -625,7 +623,7 @@ async def stream( tool_use_id, { "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for start action"}], + "content": [{"text": "Instruction and policy_port are required for start action"}], }, ) ) @@ -652,20 +650,20 @@ async def stream( { "status": "error", "content": [ - {"text": f"❌ Unknown action: {action}. Valid actions: execute, start, status, stop"} + {"text": f"Unknown action: {action}. Valid actions: execute, start, status, stop"} ], }, ) ) except Exception as e: - logger.error(f"❌ {self.tool_name_str} error: {e}") + logger.error(f"{self.tool_name_str} error: {e}") yield ToolResultEvent( self._make_tool_result( tool_use_id, { "status": "error", - "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}], + "content": [{"text": f"{self.tool_name_str} error: {str(e)}"}], }, ) ) @@ -686,7 +684,7 @@ def cleanup(self): logger.info(f"🧹 {self.tool_name_str} cleanup completed") except Exception as e: - logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}") + logger.error(f"Cleanup error for {self.tool_name_str}: {e}") def __del__(self): """Destructor to ensure cleanup.""" @@ -730,7 +728,7 @@ async def get_status(self) -> dict[str, Any]: return status_data except Exception as e: - logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}") + logger.error(f"Error getting status for {self.tool_name_str}: {e}") return { "robot_name": self.tool_name_str, "error": str(e), @@ -752,7 +750,7 @@ async def stop(self): # Cleanup resources self.cleanup() - logger.info(f"🛑 {self.tool_name_str} stopped and disconnected") + logger.info(f"{self.tool_name_str} stopped and disconnected") except Exception as e: - logger.error(f"❌ Error stopping robot: {e}") + logger.error(f"Error stopping robot: {e}") diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py index d9674a9..ed94069 100644 --- a/strands_robots/simulation/__init__.py +++ b/strands_robots/simulation/__init__.py @@ -1,15 +1,25 @@ -"""Strands Robots Simulation — multi-backend simulation framework. +"""Strands Robots Simulation - multi-backend simulation framework. Architecture:: simulation/ - ├── __init__.py ← this file (re-exports, lazy loading) - ├── base.py ← SimEngine ABC - ├── factory.py ← create_simulation() + backend registration - ├── models.py ← shared dataclasses (SimWorld, SimRobot, ...) - └── model_registry.py ← URDF/MJCF resolution (shared across backends) - - # MuJoCo backend added in subsequent PRs. + ├ __init__.py ← this file (re-exports, lazy loading) + ├ base.py ← SimEngine ABC + ├ factory.py ← create_simulation() + backend registration + ├ models.py ← shared dataclasses (SimWorld, SimRobot, ...) + ├ model_registry.py ← URDF/MJCF resolution (shared across backends) + └ mujoco/ ← MuJoCo CPU backend + ├ __init__.py + ├ backend.py ← lazy mujoco import + GL config + ├ mjcf_builder.py ← MJCF XML builder + ├ physics.py ← advanced physics (raycasting, jacobians, forces) + ├ scene_ops.py ← XML round-trip inject/eject + ├ rendering.py ← render RGB/depth, observations + ├ policy_runner.py ← run_policy, eval_policy, replay + ├ randomization.py ← domain randomization + ├ recording.py ← LeRobotDataset recording + ├ tool_spec.json ← AgentTool input schema + └ simulation.py ← Simulation (AgentTool orchestrator) Usage:: @@ -39,7 +49,7 @@ import importlib as _importlib from typing import Any -# --- Light imports (no heavy deps — stdlib + dataclasses only) --- +# Light imports (no heavy deps - stdlib + dataclasses only) from strands_robots.simulation.base import SimEngine from strands_robots.simulation.factory import ( create_simulation, @@ -62,10 +72,15 @@ TrajectoryStep, ) -# --- Heavy imports (lazy — loaded when mujoco backend is available) --- -# MuJoCo-specific lazy imports will be added when the mujoco/ subpackage -# is introduced. For now, only the lightweight foundation is available. -_LAZY_IMPORTS: dict[str, tuple[str, str]] = {} +# Heavy imports (lazy - need strands SDK + mujoco) +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "Simulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "MuJoCoSimulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "MJCFBuilder": ("strands_robots.simulation.mujoco.mjcf_builder", "MJCFBuilder"), + "_configure_gl_backend": ("strands_robots.simulation.mujoco.backend", "_configure_gl_backend"), + "_ensure_mujoco": ("strands_robots.simulation.mujoco.backend", "_ensure_mujoco"), + "_is_headless": ("strands_robots.simulation.mujoco.backend", "_is_headless"), +} __all__ = [ @@ -75,9 +90,9 @@ "create_simulation", "list_backends", "register_backend", - # Default backend alias (available when mujoco backend is installed) - # "Simulation", - # "MuJoCoSimulation", + # Default backend alias + "Simulation", + "MuJoCoSimulation", # Shared dataclasses "SimStatus", "SimRobot", @@ -85,8 +100,8 @@ "SimCamera", "SimWorld", "TrajectoryStep", - # MuJoCo builder (available when mujoco backend is installed) - # "MJCFBuilder", + # MuJoCo builder + "MJCFBuilder", # Model registry "register_urdf", "resolve_model", @@ -108,4 +123,4 @@ def __getattr__(name: str) -> Any: # NOTE: MuJoCo GL backend configuration lives in the top-level # strands_robots/__init__.py to ensure it runs before any `import mujoco`. -# Do NOT duplicate it here — see PR #86 for the canonical location. +# Do NOT duplicate it here - see PR #86 for the canonical location. diff --git a/strands_robots/simulation/base.py b/strands_robots/simulation/base.py index 7ca2098..71ee141 100644 --- a/strands_robots/simulation/base.py +++ b/strands_robots/simulation/base.py @@ -1,7 +1,7 @@ -"""Simulation ABC — backend-agnostic interface for all simulation engines. +"""Simulation ABC - backend-agnostic interface for all simulation engines. Every simulation backend (MuJoCo, Isaac, Newton) implements this interface. -Agent tools and the Robot() factory interact through these methods only — +Agent tools and the Robot() factory interact through these methods only - they never touch backend-specific APIs directly. Usage:: @@ -20,7 +20,10 @@ import logging from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from strands_robots.policies import Policy logger = logging.getLogger(__name__) @@ -29,18 +32,25 @@ class SimEngine(ABC): """Abstract base class for simulation engines. Defines the contract that all backends (MuJoCo, Isaac, Newton) must - implement. This is the *programmatic* API — the AgentTool layer + implement. This is the *programmatic* API - the AgentTool layer wraps it with tool_spec/stream for LLM access. Method categories: - **Required** (``@abstractmethod``): Core simulation loop — world - lifecycle, entity management, observation/action, rendering. Every - physics engine must implement these to be usable. + **Required** (``@abstractmethod``): Core simulation loop - world + lifecycle, entity management, observation/action, rendering, robot + discovery. Every physics engine must implement these to be usable. + + **Provided** (concrete base-class methods): Policy orchestration + (``run_policy`` / ``start_policy`` / ``replay_episode`` / ``eval_policy``) + is implemented once in this ABC as a facade over the abstract primitives. + Backends inherit them for free by implementing the primitives. They + *may* override for backend-specific optimisations (e.g. GPU-batched + policy inference on Isaac). **Optional** (default raises ``NotImplementedError``): Higher-level - features — scene loading, policy running, domain randomization, - contact queries. Backends opt in by overriding only what they support. + features - scene loading, domain randomization, contact queries. + Backends opt in by overriding only what they support. Lifecycle:: @@ -61,7 +71,7 @@ class SimEngine(ABC): sim.destroy() """ - # --- World lifecycle --- + # World lifecycle @abstractmethod def create_world( @@ -93,7 +103,7 @@ def get_state(self) -> dict[str, Any]: """Get full simulation state summary.""" ... - # --- Robot management --- + # Robot management @abstractmethod def add_robot( @@ -112,7 +122,26 @@ def remove_robot(self, name: str) -> dict[str, Any]: """Remove a robot from the simulation.""" ... - # --- Object management --- + @abstractmethod + def list_robots(self) -> list[str]: + """Return ordered list of robot names currently in the world. + + Used by the backend-agnostic ``PolicyRunner`` to resolve a + default robot when the caller omits ``robot_name``. + """ + ... + + @abstractmethod + def robot_joint_names(self, robot_name: str) -> list[str]: + """Return ordered joint names for ``robot_name``. + + Used by ``Policy.set_robot_state_keys`` and by + ``PolicyRunner.replay`` to map dataset action-vector indices to + named joints. Order must match the backend's action ordering. + """ + ... + + # Object management @abstractmethod def add_object( @@ -136,16 +165,39 @@ def remove_object(self, name: str) -> dict[str, Any]: """Remove an object from the scene.""" ... - # --- Observation / Action --- + # Observation / Action @abstractmethod - def get_observation(self, robot_name: str | None = None, camera_name: str | None = None) -> dict[str, Any]: - """Get observation from simulation. - - Convenience method that delegates to the underlying Robot - abstraction. Provides a unified interface for agent tools - that interact with simulation without needing to distinguish - between Robot and Sim layers. + def get_observation(self, robot_name: str | None = None, *, skip_images: bool = False) -> dict[str, Any]: + """Get full observation for a robot: joint state + all attached cameras. + + Unified observation consumed by :class:`Policy` and + :class:`~strands_robots.simulation.policy_runner.PolicyRunner`. + Backends MUST return a dict with the following schema; extra keys + are allowed. + + Schema: + - ``""`` (float): One entry per joint on the robot, + keyed by the *short* joint name (e.g. ``"shoulder_pan"``). + The schema is stable regardless of multi-robot namespacing + at the physics-engine level. + - ``""`` (np.ndarray): One RGB uint8 frame per + camera associated with the robot, keyed by camera name. + Shape ``(H, W, 3)``. Cameras whose render fails MAY be + omitted; joint state MUST still be returned. + + Single-camera rendering is :meth:`render`'s job, not this method's. + For batched multi-robot observation (future Isaac / Newton), add a + separate ``get_observations(robot_names)`` method - do NOT extend + this one. + + Args: + robot_name: Which robot to observe. If ``None`` and exactly one + robot exists, that robot is used; otherwise returns ``{}``. + + Returns: + Observation dict per schema above. Returns ``{}`` if the world + is not yet created or ``robot_name`` is unknown. """ ... @@ -157,10 +209,14 @@ def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_s abstraction. The simulation engine acts as a facade so agent tools can use ``sim.send_action()`` without knowing about the Robot/Policy layer. + + Backends are responsible for internal thread-safety (e.g. + MuJoCo must acquire an internal lock here). ``PolicyRunner`` + does not manage locks. """ ... - # --- Rendering --- + # Rendering @abstractmethod def render( @@ -174,22 +230,241 @@ def render( """ ... - # --- Optional overrides (have default no-op implementations) --- + # Policy orchestration (concrete facade, not abstract) - def load_scene(self, scene_path: str) -> dict[str, Any]: - """Load a complete scene from file. Override per backend.""" - raise NotImplementedError("load_scene not implemented by this backend") + def run_policy( + self, + robot_name: str, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + duration: float = 10.0, + control_frequency: float = 50.0, + action_horizon: int = 8, + fast_mode: bool = False, + video: dict[str, Any] | None = None, + policy_object: Policy | None = None, + n_steps: int | None = None, + max_steps: int | None = None, + max_onframe_failures: int | None = None, + ) -> dict[str, Any]: + """Run a policy loop in the simulation (blocking). + + Default implementation delegates to the backend-agnostic + :class:`~strands_robots.simulation.policy_runner.PolicyRunner`. + Backends MAY override for backend-specific optimisations + (e.g. GPU-batched policy inference on Isaac). + + Args: + robot_name: Robot to control. + policy_provider: Name passed to + :func:`strands_robots.policies.create_policy`. + policy_config: Opaque dict of provider-specific kwargs + (``observation_mapping``, ``action_mapping``, ``host``, + ``port``, ``api_token``, ``pretrained_name_or_path``, + ``trust_remote_code``, ``actions_per_step``, + ``use_processor``, ``processor_overrides``, ``device``, + ...). Forwarded verbatim to ``create_policy``. + instruction: Natural-language instruction for the policy. + duration: Wall-clock seconds to run. + control_frequency: Target Hz for policy queries. + action_horizon: Max actions per policy call. + fast_mode: Skip real-time sleep between steps. + video: Optional video-recording config dict. Accepted keys: + ``path`` (str, output MP4 - required to enable recording), + ``fps`` (int, default 30), ``camera`` (str, default backend + default), ``width`` (int, default 640), ``height`` (int, + default 480). See :class:`~strands_robots.simulation.policy_runner.VideoConfig`. + For extension points beyond video (custom telemetry, + dataset recording), backends plug into + ``PolicyRunner.run``'s ``on_frame`` hook via + :meth:`_make_run_policy_hook`. + + Returns: + Standard status dict. + """ + from strands_robots.policies import create_policy + from strands_robots.simulation.policy_runner import PolicyRunner, VideoConfig + + # accept n_steps (or legacy max_steps) as an alternate horizon + # specification. duration = n_steps / control_frequency. If both + # are passed, n_steps wins (primary per DoD). + if n_steps is None and max_steps is not None: + n_steps = int(max_steps) + if n_steps is not None: + if n_steps <= 0: + return { + "status": "error", + "content": [{"text": f"run_policy: n_steps must be > 0, got {n_steps}."}], + } + if control_frequency <= 0: + return { + "status": "error", + "content": [{"text": "run_policy: control_frequency must be > 0 when n_steps is used."}], + } + duration = float(n_steps) / float(control_frequency) + + if robot_name not in self.list_robots(): + return { + "status": "error", + "content": [{"text": f"Robot '{robot_name}' not found."}], + } + + if policy_object is not None: + # Pre-built policy path - skip the expensive create_policy call. + # Caller is responsible for policy.set_robot_state_keys(...) if needed, + # but we set it here defensively so the semantics match the provider path. + policy = policy_object + else: + policy = create_policy(policy_provider, **(policy_config or {})) + policy.set_robot_state_keys(self.robot_joint_names(robot_name)) + + on_frame = self._make_run_policy_hook(robot_name, instruction) + + return PolicyRunner(self).run( + robot_name, + policy, + instruction=instruction, + duration=duration, + control_frequency=control_frequency, + action_horizon=action_horizon, + fast_mode=fast_mode, + video=VideoConfig.from_dict(video), + on_frame=on_frame, + max_onframe_failures=max_onframe_failures, + ) + + def start_policy( + self, + robot_name: str, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + duration: float = 10.0, + control_frequency: float = 50.0, + action_horizon: int = 8, + fast_mode: bool = False, + video: dict[str, Any] | None = None, + policy_object: Policy | None = None, + n_steps: int | None = None, + max_steps: int | None = None, + ) -> dict[str, Any]: + """Start policy execution in a background thread (non-blocking). - def run_policy(self, robot_name: str, policy_provider: str = "mock", **kwargs: Any) -> dict[str, Any]: - """Run a policy loop in the simulation. + Default implementation: synchronous passthrough to ``run_policy``. + Backends that support true background execution (like MuJoCo via + its ``ThreadPoolExecutor``) should override. - Orchestration shortcut: internally creates a Policy, then loops - ``obs → policy(obs) → send_action(action) → step()``. - Intentionally placed on SimEngine as a facade for agent tools - that need a single ``simulation(action="run_policy")`` interface. - Override per backend. + accepts ``n_steps`` (primary) or legacy ``max_steps`` as an + alternate to ``duration``. See ``run_policy`` for conversion rules. """ - raise NotImplementedError("run_policy not implemented by this backend") + return self.run_policy( + robot_name, + policy_provider=policy_provider, + policy_config=policy_config, + instruction=instruction, + duration=duration, + control_frequency=control_frequency, + action_horizon=action_horizon, + fast_mode=fast_mode, + video=video, + policy_object=policy_object, + n_steps=n_steps, + max_steps=max_steps, + ) + + def replay_episode( + self, + repo_id: str, + robot_name: str | None = None, + episode: int = 0, + root: str | None = None, + speed: float = 1.0, + action_key_map: list[str] | None = None, + ) -> dict[str, Any]: + """Replay a LeRobotDataset episode via ``PolicyRunner.replay``. + + Override per backend for optimised replay (e.g. direct ctrl + writes) only when measured necessary. + """ + from strands_robots.simulation.policy_runner import PolicyRunner + + return PolicyRunner(self).replay( + repo_id, + robot_name=robot_name, + episode=episode, + root=root, + speed=speed, + action_key_map=action_key_map, + ) + + def eval_policy( + self, + robot_name: str | None = None, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + n_episodes: int = 1, + max_steps: int = 300, + success_fn: str | None = None, + ) -> dict[str, Any]: + """Multi-episode policy evaluation via ``PolicyRunner.evaluate``. + + ``robot_name`` is required - eval_policy used to silently pick + the first robot, which is surprising in multi-robot scenes. + ``n_episodes`` default lowered from 10 to 1 (callers opt in to + longer evals explicitly). + """ + from strands_robots.policies import create_policy + from strands_robots.simulation.policy_runner import PolicyRunner + + if not robot_name: + return { + "status": "error", + "content": [{"text": "eval_policy requires 'robot_name'."}], + } + robots = self.list_robots() + if not robots: + return {"status": "error", "content": [{"text": "No robots in sim. Add one first."}]} + if robot_name not in robots: + return { + "status": "error", + "content": [{"text": f"Robot '{robot_name}' not found."}], + } + resolved_robot = robot_name + + policy = create_policy(policy_provider, **(policy_config or {})) + policy.set_robot_state_keys(self.robot_joint_names(resolved_robot)) + + return PolicyRunner(self).evaluate( + resolved_robot, + policy, + instruction=instruction, + n_episodes=n_episodes, + max_steps=max_steps, + success_fn=success_fn, + ) + + def _make_run_policy_hook(self, robot_name: str, instruction: str) -> Any: + """Override to return an ``on_frame(step, obs, action)`` callable. + + Used by backends that want to layer in recording / telemetry + without subclassing :class:`PolicyRunner`. Default: no hook. + + Args: + robot_name: Robot being controlled this run. + instruction: Instruction passed to this run. + + Returns: + Callable or ``None``. + """ + return None + + # Optional overrides (have default no-op implementations) + + def load_scene(self, scene_path: str) -> dict[str, Any]: + """Load a complete scene from file. Override per backend.""" + raise NotImplementedError("load_scene not implemented by this backend") def randomize(self, **kwargs: Any) -> dict[str, Any]: """Apply domain randomization. @@ -217,6 +492,6 @@ def __del__(self) -> None: try: self.cleanup() except Exception as e: - # Best-effort cleanup during GC — exceptions can't propagate + # Best-effort cleanup during GC - exceptions can't propagate # from __del__ (CPython ignores them), so log for visibility. logger.warning("Cleanup error during __del__: %s", e) diff --git a/strands_robots/simulation/factory.py b/strands_robots/simulation/factory.py index e7b0a5b..2d6ab03 100644 --- a/strands_robots/simulation/factory.py +++ b/strands_robots/simulation/factory.py @@ -1,4 +1,4 @@ -"""Simulation factory — create_simulation() and runtime backend registration. +"""Simulation factory - create_simulation() and runtime backend registration. Mirrors the policy factory pattern: JSON-driven defaults with runtime override capability. Backends are lazy-loaded on first use. @@ -34,9 +34,7 @@ logger = logging.getLogger(__name__) -# ───────────────────────────────────────────────────────────────────── -# Built-in backend registry (lazy loaders — no imports at module load) -# ───────────────────────────────────────────────────────────────────── +# Built-in backend registry (lazy loaders - no imports at module load) _BUILTIN_BACKENDS: dict[str, tuple[str, str]] = { "mujoco": ( @@ -59,9 +57,7 @@ DEFAULT_BACKEND = "mujoco" -# ───────────────────────────────────────────────────────────────────── # Runtime registration (for user-defined backends not in built-ins) -# ───────────────────────────────────────────────────────────────────── _runtime_registry: dict[str, Callable[[], type[SimEngine]]] = {} _runtime_aliases: dict[str, str] = {} diff --git a/strands_robots/simulation/model_registry.py b/strands_robots/simulation/model_registry.py index b7af5e9..6a077ed 100644 --- a/strands_robots/simulation/model_registry.py +++ b/strands_robots/simulation/model_registry.py @@ -1,11 +1,11 @@ -"""Robot model resolution — URDF registry + asset manager. +"""Robot model resolution - URDF registry + asset manager. Bridges the robot registry with actual URDF/MJCF files on disk. Resolution order for :func:`resolve_model`: 1. User-registered URDFs (:func:`register_urdf`) 2. URDF search paths (``STRANDS_ASSETS_DIR``, CWD, etc.) - 3. Asset manager (``robot_descriptions`` — fallback for standard robots) + 3. Asset manager (``robot_descriptions`` - fallback for standard robots) """ from __future__ import annotations @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) # URDF search paths are resolved lazily via :func:`strands_robots.utils.get_search_paths` -# at every lookup — this avoids snapshotting ``Path.cwd()`` and ``STRANDS_ASSETS_DIR`` +# at every lookup - this avoids snapshotting ``Path.cwd()`` and ``STRANDS_ASSETS_DIR`` # at import time, which caused silent wrong-path bugs when tests/notebooks chdir after # import. @@ -39,7 +39,7 @@ except ImportError: _HAS_REGISTRY = False -# Logged lazily on first resolution via _log_configuration_once() — +# Logged lazily on first resolution via _log_configuration_once() - # avoids noisy INFO on every ``import strands_robots``. _CONFIG_LOGGED = False @@ -68,7 +68,7 @@ def resolve_model(name: str, prefer_scene: bool = True) -> str | None: Resolution order (local assets take priority): 1. User-registered URDFs (custom user registrations) 2. URDF search paths (STRANDS_ASSETS_DIR, CWD, etc.) - 3. Asset manager (robot_descriptions — fallback for standard robots) + 3. Asset manager (robot_descriptions - fallback for standard robots) """ _log_configuration_once() # 1+2. Check local/custom paths first (user overrides win) @@ -92,7 +92,7 @@ def resolve_model(name: str, prefer_scene: bool = True) -> str | None: def resolve_urdf(data_config: str) -> str | None: """Resolve a data_config name to a URDF file path. - Also checks the registry's ``legacy_urdf`` field — a backward-compatible + Also checks the registry's ``legacy_urdf`` field - a backward-compatible path for robots that were registered before the MJCF asset system was introduced (e.g. robots originally configured with raw URDF paths). """ @@ -137,6 +137,6 @@ def list_available_models() -> str: lines = ["Registered URDFs:"] for name, path in _URDF_REGISTRY.items(): resolved = resolve_urdf(name) - status = "✅" if resolved else "❌" - lines.append(f" {status} {name}: {path}") + status = "[OK]" if resolved else "[MISSING]" + lines.append(f"{status} {name}: {path}") return "\n".join(lines) diff --git a/strands_robots/simulation/models.py b/strands_robots/simulation/models.py index e339d15..d5282a2 100644 --- a/strands_robots/simulation/models.py +++ b/strands_robots/simulation/models.py @@ -104,14 +104,14 @@ class SimWorld: escape hatches, each with a distinct role so backend implementers know which to use: - * ``_model``: the physics engine's **core model handle** — the single + * ``_model``: the physics engine's **core model handle** - the single compiled/loaded representation of the scene (e.g. ``mujoco.MjModel``, Isaac's ``Scene``, PyBullet's body registry). Every backend has one. - * ``_data``: the physics engine's **core simulation state handle** — + * ``_data``: the physics engine's **core simulation state handle** - the mutable per-step state companion to ``_model`` (e.g. ``mujoco.MjData``, Isaac's ``World``). Every backend has one. * ``_backend_state``: a **catch-all dict** for everything else the - backend needs to persist — generated XML, temp dirs, recording + backend needs to persist - generated XML, temp dirs, recording buffers, caches, etc. Prefer this over adding new fields here. All three are typed ``Any``/``dict`` so nothing leaks engine-specific @@ -127,7 +127,7 @@ class SimWorld: status: SimStatus = SimStatus.IDLE sim_time: float = 0.0 step_count: int = 0 - # Engine core handles — set after the backend builds the world. + # Engine core handles - set after the backend builds the world. # Use these for the primary model/state objects only; put everything # else in ``_backend_state`` below. _model: Any = None # Engine-specific model handle (e.g. MjModel, Scene) @@ -138,6 +138,6 @@ class SimWorld: # Prefer this over adding new fields to ``SimWorld``. _backend_state: dict[str, Any] = field(default_factory=dict) # Physics state checkpoints (used by save_state/restore_state in PR #85). - # Kept as a top-level field — requested by @yinsong1986 during review to + # Kept as a top-level field - requested by @yinsong1986 during review to # avoid monkey-patching when ``reset()`` creates a fresh ``SimWorld``. _checkpoints: dict[str, Any] = field(default_factory=dict) diff --git a/strands_robots/simulation/mujoco/__init__.py b/strands_robots/simulation/mujoco/__init__.py new file mode 100644 index 0000000..03c6a03 --- /dev/null +++ b/strands_robots/simulation/mujoco/__init__.py @@ -0,0 +1,32 @@ +"""MuJoCo simulation backend for strands-robots. + +CPU-based physics with offscreen rendering. No GPU required. +Supports URDF/MJCF loading, multi-robot scenes, policy execution, +domain randomization, and LeRobotDataset recording. + +Usage:: + + from strands_robots.simulation.mujoco import MuJoCoSimulation + + sim = MuJoCoSimulation(tool_name="my_sim") + sim.create_world() + sim.add_robot("so100", data_config="so100") + sim.run_policy("so100", policy_provider="mock", instruction="wave") + +Or via the top-level alias:: + + from strands_robots.simulation import Simulation # → MuJoCoSimulation +""" + +__all__ = [ + "MuJoCoSimulation", +] + + +def __getattr__(name: str) -> "type": + if name == "MuJoCoSimulation": + from strands_robots.simulation.mujoco.simulation import Simulation as _Sim + + globals()["MuJoCoSimulation"] = _Sim + return _Sim + raise AttributeError(f"module 'strands_robots.simulation.mujoco' has no attribute {name!r}") diff --git a/strands_robots/simulation/mujoco/backend.py b/strands_robots/simulation/mujoco/backend.py new file mode 100644 index 0000000..fc0f0d4 --- /dev/null +++ b/strands_robots/simulation/mujoco/backend.py @@ -0,0 +1,156 @@ +"""MuJoCo lazy import and GL backend configuration.""" + +import ctypes +import logging +import os +import sys +from typing import Any + +logger = logging.getLogger(__name__) + +_mujoco = None +_mujoco_viewer = None + + +def _is_headless() -> bool: + """Detect if running in a headless environment (no display server). + + Returns True on Linux when no DISPLAY or WAYLAND_DISPLAY is set, + which means GLFW-based rendering will fail. + + Windows and macOS are always False because MuJoCo uses native + windowing backends (WGL on Windows, CGL on macOS) that support + offscreen rendering without X11/Wayland. The EGL/OSMesa fallback + is Linux-specific. + """ + if sys.platform != "linux": + return False + if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"): + return False + return True + + +def _configure_gl_backend() -> None: # noqa: C901 + """Auto-configure MuJoCo's OpenGL backend for headless environments. + + MuJoCo reads MUJOCO_GL at import time to select the OpenGL backend: + - "egl" → EGL (GPU-accelerated offscreen, requires libEGL + NVIDIA driver) + - "osmesa" → OSMesa (CPU software rendering, slower but always works) + - "glfw" → GLFW (default, requires X11/Wayland display server) + + This function MUST be called before `import mujoco`. Setting MUJOCO_GL + after import has no effect - the backend is locked at import time. + + Never overrides a user-set MUJOCO_GL value. + """ + if os.environ.get("MUJOCO_GL"): + logger.debug(f"MUJOCO_GL already set to '{os.environ['MUJOCO_GL']}', respecting user config") + return + + if not _is_headless(): + return + + # Headless Linux - probe for EGL first (GPU-accelerated), then fall back to OSMesa (CPU) + try: + ctypes.cdll.LoadLibrary("libEGL.so.1") + os.environ["MUJOCO_GL"] = "egl" + logger.info("Headless environment detected - using MUJOCO_GL=egl (GPU-accelerated offscreen)") + return + except OSError: + pass + + try: + ctypes.cdll.LoadLibrary("libOSMesa.so") + os.environ["MUJOCO_GL"] = "osmesa" + logger.info("Headless environment detected - using MUJOCO_GL=osmesa (CPU software rendering)") + return + except OSError: + pass + + logger.warning( + "Headless environment detected but neither EGL nor OSMesa found. " + "MuJoCo rendering will likely fail. Install one of:\n" + " GPU: apt-get install libegl1-mesa-dev (or NVIDIA driver provides libEGL)\n" + " CPU: apt-get install libosmesa6-dev\n" + "Then set: export MUJOCO_GL=egl (or osmesa)" + ) + + +def _ensure_mujoco() -> "Any": + """Lazy import MuJoCo to avoid hard dependency. + + Auto-configures the OpenGL backend for headless environments before + importing mujoco, since MUJOCO_GL must be set at import time. + + Uses require_optional() for consistent dependency management across + the strands-robots package. + """ + global _mujoco, _mujoco_viewer + if _mujoco is None: + _configure_gl_backend() + from strands_robots.utils import require_optional + + _mujoco = require_optional( + "mujoco", + pip_install="mujoco", + extra="sim-mujoco", + purpose="MuJoCo simulation", + ) + if _mujoco_viewer is None and not _is_headless(): + try: + import mujoco.viewer as viewer + + _mujoco_viewer = viewer + except ImportError: + pass + return _mujoco + + +_rendering_available: bool | None = None + + +def _can_render() -> bool: + """Check if MuJoCo offscreen rendering is available. + + Probes once by creating a minimal Renderer. Result is cached. + Returns False on headless environments without EGL/OSMesa. + + On headless Linux, if MUJOCO_GL is not set after _configure_gl_backend() + ran, it means neither EGL nor OSMesa is available. In that case the + default GLFW backend would be used, which calls glfw.init() → abort() + at the C level (SIGABRT), killing the entire process before Python can + catch the error. We short-circuit to False to avoid the fatal probe. + """ + global _rendering_available + if _rendering_available is not None: + return _rendering_available + + # Guard: on headless systems without an offscreen GL backend configured, + # mj.Renderer() will use GLFW which triggers a C-level abort (SIGABRT). + # Skip the probe entirely - rendering is impossible anyway. + if _is_headless() and not os.environ.get("MUJOCO_GL"): + _rendering_available = False + logger.warning( + "Headless environment without EGL/OSMesa - rendering disabled. " + "Physics and joint observations will still work. " + "Install libegl1-mesa-dev or libosmesa6-dev for camera rendering." + ) + return False + + mj = _ensure_mujoco() + try: + model = mj.MjModel.from_xml_string("") + renderer = mj.Renderer(model, height=1, width=1) + renderer.close() + del renderer + _rendering_available = True + logger.info("MuJoCo rendering available") + except Exception as e: + _rendering_available = False + logger.warning( + "MuJoCo rendering unavailable: %s. " + "Physics/policy will work, but render/camera observations will be skipped. " + "Install EGL or OSMesa for offscreen rendering.", + e, + ) + return _rendering_available diff --git a/strands_robots/simulation/mujoco/mjcf_builder.py b/strands_robots/simulation/mujoco/mjcf_builder.py new file mode 100644 index 0000000..8fcf310 --- /dev/null +++ b/strands_robots/simulation/mujoco/mjcf_builder.py @@ -0,0 +1,273 @@ +"""MJCF XML builder - programmatic scene construction.""" + +import logging +import os +import re +import subprocess +import tempfile + +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +_VALID_NAME_RE = re.compile(r"^[a-zA-Z0-9_][a-zA-Z0-9_.\-]{0,127}$") + + +def _sanitize_name(name: str) -> str: + """Validate and sanitize an object/body name for safe MJCF XML embedding. + + Raises ValueError if name contains characters that could cause XML injection. + """ + if not _VALID_NAME_RE.match(name): + raise ValueError(f"Invalid simulation name {name!r}: must match [a-zA-Z0-9_][a-zA-Z0-9_.\\-]{{0,127}}") + return name + + +def _camera_xyaxes_from_target( + position: list[float], + target: list[float], + up: tuple[float, float, float] = (0.0, 0.0, 1.0), +) -> str | None: + """compute MJCF ``xyaxes`` attribute so a camera looks at ``target``. + + MuJoCo cameras with ``mode='fixed'`` need an explicit orientation. Without + xyaxes/quat MuJoCo uses the default -Z look direction, so ``add_camera``'s + ``target`` was completely ignored - every custom camera rendered the + default view and three cameras at different positions produced byte- + identical near-black PNGs. + + MJCF xyaxes format: "x0 x1 x2 y0 y1 y2" - the camera's LOCAL +X and +Y + axes expressed in world frame. Camera looks down its local -Z. + + Convention here: + forward (cam -Z) = normalize(target - position) + right (cam +X) = normalize(cross(forward, up)) + down (cam -Y) = normalize(cross(right, forward)) + -> cam +Y = -down (i.e. "image up" points toward world up) + + Returns None on a degenerate case (target == position, or colinear up). + Callers should surface a clear error in that case rather than silently + emitting the default orientation. + """ + import math + + fx, fy, fz = target[0] - position[0], target[1] - position[1], target[2] - position[2] + flen = math.sqrt(fx * fx + fy * fy + fz * fz) + if flen < 1e-9: + return None + fx, fy, fz = fx / flen, fy / flen, fz / flen + + ux, uy, uz = up + # right = forward × up + rx = fy * uz - fz * uy + ry = fz * ux - fx * uz + rz = fx * uy - fy * ux + rlen = math.sqrt(rx * rx + ry * ry + rz * rz) + if rlen < 1e-9: + # forward is parallel to up - fall back to world-X as right. + rx, ry, rz = 1.0, 0.0, 0.0 + rlen = 1.0 + rx, ry, rz = rx / rlen, ry / rlen, rz / rlen + + # image-up = right × forward (so the Y axis points away from world-down) + iy_x = ry * fz - rz * fy + iy_y = rz * fx - rx * fz + iy_z = rx * fy - ry * fx + + return f"{rx:.6f} {ry:.6f} {rz:.6f} {iy_x:.6f} {iy_y:.6f} {iy_z:.6f}" + + +class MJCFBuilder: + """Builds MuJoCo MJCF XML from SimWorld state.""" + + @staticmethod + def build_objects_only(world: SimWorld) -> str: + """Build MJCF XML for a world with only objects (robots loaded separately).""" + _ensure_mujoco() + + parts = [] + parts.append('') + parts.append(' ') + + gx, gy, gz = world.gravity + parts.append(f' ") + + return "\n".join(parts) + + @staticmethod + def _object_xml(obj: SimObject, indent: int = 4) -> str: + """Generate MJCF XML for a single object.""" + pad = " " * indent + px, py, pz = obj.position + qw, qx, qy, qz = obj.orientation + r, g, b, a = obj.color + lines = [] + + lines.append(f'{pad}') + + if not obj.is_static: + lines.append(f'{pad} ') + lines.append(f'{pad} ') + + if obj.shape == "box": + sx, sy, sz = [s / 2 for s in obj.size] + lines.append( + f'{pad} ' + ) + elif obj.shape == "sphere": + radius = obj.size[0] / 2 if obj.size else 0.025 + lines.append( + f'{pad} ' + ) + elif obj.shape == "cylinder": + radius = obj.size[0] / 2 if obj.size else 0.025 + half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 + lines.append( + f'{pad} ' + ) + elif obj.shape == "capsule": + radius = obj.size[0] / 2 if obj.size else 0.025 + half_h = obj.size[2] / 2 if len(obj.size) > 2 else 0.05 + lines.append( + f'{pad} ' + ) + elif obj.shape == "mesh" and obj.mesh_path: + lines.append( + f'{pad} ' + ) + elif obj.shape == "plane": + sx = obj.size[0] if obj.size else 1.0 + sy = obj.size[1] if len(obj.size) > 1 else sx + lines.append( + f'{pad} ' + ) + + lines.append(f"{pad}") + return "\n".join(lines) + + @staticmethod + def compose_multi_robot_scene( + robots: dict[str, SimRobot], + objects: dict[str, SimObject], + cameras: dict[str, SimCamera], + world: SimWorld, + ) -> str: + """Compose a multi-robot scene by merging URDF-derived MJCF fragments.""" + mj = _ensure_mujoco() + world._backend_state["tmpdir"] = tempfile.TemporaryDirectory(prefix="strands_sim_") + tmpdir = world._backend_state["tmpdir"].name + + robot_xmls = {} + for robot_name, robot in robots.items(): + try: + model = mj.MjModel.from_xml_path(str(robot.urdf_path)) + robot_xml_path = os.path.join(tmpdir, f"{robot_name}.xml") + mj.mj_saveLastXML(robot_xml_path, model) + robot_xmls[robot_name] = robot_xml_path + logger.debug("Converted %s → %s", robot.urdf_path, robot_xml_path) + except (FileNotFoundError, OSError, subprocess.CalledProcessError) as e: + logger.error("Failed to convert URDF for '%s': %s", robot_name, e) + raise + + parts = [] + parts.append('') + parts.append(' ') + + gx, gy, gz = world.gravity + parts.append(f' ") + + master_xml = "\n".join(parts) + master_path = os.path.join(tmpdir, "master_scene.xml") + with open(master_path, "w") as f: + f.write(master_xml) + + return master_path diff --git a/strands_robots/simulation/mujoco/physics.py b/strands_robots/simulation/mujoco/physics.py new file mode 100644 index 0000000..b1ad4f8 --- /dev/null +++ b/strands_robots/simulation/mujoco/physics.py @@ -0,0 +1,1158 @@ +"""Physics mixin - advanced MuJoCo physics introspection and manipulation. + +Exposes the deep MuJoCo C API through clean Python methods: +- Raycasting (mj_ray) +- Jacobians (mj_jacBody, mj_jacSite, mj_jacGeom) +- Energy computation (mj_energyPos, mj_energyVel) +- External forces (mj_applyFT, xfrc_applied) +- Mass matrix (mj_fullM) +- State checkpointing (mj_getState, mj_setState) +- Inverse dynamics (mj_inverse) +- Body/joint introspection (poses, velocities, accelerations) +- Direct joint position/velocity control (qpos, qvel) +- Runtime model modification (mass, friction, color, size) +- Sensor readout (sensordata) +- Contact force analysis (mj_contactForce) +""" + +import logging +from typing import TYPE_CHECKING, Any + +import numpy as np + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class PhysicsMixin: + """Advanced MuJoCo physics capabilities mixed into ``Simulation``. + + Lives at roughly ``self._world._data`` + ``self._world._model`` level: + reads/writes MuJoCo arrays directly for checkpointing, raycasts, + jacobians, joint control, sensor readout, etc. + + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world``, ``self._lock``, and the host's + ``_require_no_running_policy`` / ``_require_world`` / ``_prune_done_futures`` + helpers. ``TYPE_CHECKING`` stubs below exist so mypy accepts those + lookups; they are a documentary contract, not an enforceable protocol. + + Naming: methods match action names in tool_spec.json for direct dispatch. + """ + + if TYPE_CHECKING: + import threading + + from strands_robots.simulation.models import SimWorld + + _lock: "threading.Lock" + _world: "SimWorld | None" + + def _require_no_running_policy( + self, action_name: str, robot_name: str | None = None + ) -> dict[str, Any] | None: ... + def _require_world(self) -> dict[str, Any] | None: ... + + # State Checkpointing + + def save_state(self, name: str = "default") -> dict[str, Any]: + """Save the full physics state (qpos, qvel, act, time) to a named checkpoint. + + Uses mj_getState with mjSTATE_PHYSICS for complete state capture. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + with self._lock: + state_size = mj.mj_stateSize(model, mj.mjtState.mjSTATE_PHYSICS) + state = np.zeros(state_size) + mj.mj_getState(model, data, state, mj.mjtState.mjSTATE_PHYSICS) + + if not hasattr(self._world, "_checkpoints"): + self._world._checkpoints = {} + + self._world._checkpoints[name] = { + "state": state.copy(), + "sim_time": self._world.sim_time, + "step_count": self._world.step_count, + } + + return { + "status": "success", + "content": [ + { + "text": ( + f"💾 State '{name}' saved\n" + f" t={self._world.sim_time:.4f}s, step={self._world.step_count}\n" + f"State vector: {state_size} floats\n" + f"Checkpoints: {list(self._world._checkpoints.keys())}" + ) + } + ], + } + + def load_state(self, name: str = "default") -> dict[str, Any]: + """Restore physics state from a named checkpoint.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + # load_state during a running policy races worker thread + if err := self._require_no_running_policy("load_state"): + return err + + checkpoints = getattr(self._world, "_checkpoints", {}) + if name not in checkpoints: + available = list(checkpoints.keys()) if checkpoints else ["none"] + return { + "status": "error", + "content": [{"text": f"Checkpoint '{name}' not found. Available: {available}"}], + } + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + checkpoint = checkpoints[name] + + with self._lock: + mj.mj_setState(model, data, checkpoint["state"], mj.mjtState.mjSTATE_PHYSICS) + mj.mj_forward(model, data) + + self._world.sim_time = checkpoint["sim_time"] + self._world.step_count = checkpoint["step_count"] + + return { + "status": "success", + "content": [ + {"text": f"📂 State '{name}' restored (t={self._world.sim_time:.4f}s, step={self._world.step_count})"} + ], + } + + # External Forces + + def apply_force( + self, + body_name: str, + force: list[float] | None = None, + torque: list[float] | None = None, + point: list[float] | None = None, + ) -> dict[str, Any]: + """Apply an external force and/or torque to a body (latched). + + Uses mj_applyFT for precise force application at a world-frame point. + The force is latched in ``qfrc_applied`` and applied on every + subsequent ``mj_step`` until overwritten by the next ``apply_force`` + call. Each call zeroes the buffer first (replacing, not accumulating). + + To stop the force: ``apply_force(body, force=[0, 0, 0])``. + + Args: + body_name: Target body name. + force: [fx, fy, fz] in world frame (Newtons). + torque: [tx, ty, tz] in world frame (N·m). + point: [px, py, pz] world-frame point of force application. + Defaults to body CoM if not specified. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + # apply_force during a running policy races worker thread + if err := self._require_no_running_policy("apply_force"): + return err + + # must supply at least one non-zero force or torque + if force is None and torque is None: + return { + "status": "error", + "content": [{"text": "apply_force: specify at least one of 'force' or 'torque' (non-zero vector)."}], + } + + # Validate vector lengths before hitting numpy + for _name, _vec in (("force", force), ("torque", torque), ("point", point)): + if _vec is not None: + try: + if len(_vec) != 3: + return { + "status": "error", + "content": [ + {"text": f"apply_force: '{_name}' must be a 3-element vector [x,y,z], got {len(_vec)}"} + ], + } + except TypeError: + return { + "status": "error", + "content": [{"text": f"apply_force: '{_name}' must be a list/tuple of 3 numbers"}], + } + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + + f = np.array(force or [0, 0, 0], dtype=np.float64) + t = np.array(torque or [0, 0, 0], dtype=np.float64) + # Note: explicit [0,0,0] is a valid "clear the latched force" command; we only + # reject the case where the caller forgot both args (handled above). + p = np.array(point, dtype=np.float64) if point else data.xipos[body_id].copy() + + # Zero the buffer first so calls are idempotent (replace, not accumulate). + # NOTE: MuJoCo does NOT reset qfrc_applied in mj_step - the force + # persists on every subsequent step until the next apply_force call. + with self._lock: + data.qfrc_applied[:] = 0.0 + mj.mj_applyFT(model, data, f, t, p, body_id, data.qfrc_applied) + + return { + "status": "success", + "content": [ + { + "text": ( + f"💨 Force applied to '{body_name}' (body {body_id})\n" + f"Force: {f.tolist()} N\n" + f"Torque: {t.tolist()} N·m\n" + f"Point: {p.tolist()}" + ) + } + ], + } + + # Raycasting + + def _resolve_mj_name(self, obj_type: int, name: str) -> int: + """Look up a MuJoCo name, tolerating robot namespacing. + + For physics/introspection methods that accept raw body/joint/site + names (``get_body_state("gripper")`` etc.), we try the name + verbatim first, then fall back to trying it prefixed with every + robot's namespace. This preserves the pre-namespacing UX for + single-robot scenes while still working in multi-robot scenes + when the name is unambiguous. + + In multi-robot scenes where multiple robots contain a body with + the same short name (e.g. two so101s each having ``gripper``), + the caller MUST pass the namespaced form (``arm0/gripper``) to + disambiguate. The fallback returns the first match it finds, + which is non-deterministic - this is a deliberate + "unambiguous or explicit" contract. + """ + import mujoco as _mj + + assert self._world is not None and self._world._model is not None + model = self._world._model + mid = _mj.mj_name2id(model, obj_type, name) + if mid >= 0: + return int(mid) + if "/" in name: # already namespaced, no point retrying + return -1 + for robot in self._world.robots.values(): + if robot.namespace: + mid = _mj.mj_name2id(model, obj_type, robot.namespace + name) + if mid >= 0: + return int(mid) + return -1 + + def raycast( + self, + origin: list[float], + direction: list[float], + exclude_body: int = -1, + include_static: bool = True, + ) -> dict[str, Any]: + """Cast a ray and find the first geom intersection. + + Uses mj_ray for precise distance sensing / obstacle detection. + + Args: + origin: [x, y, z] ray start point in world frame. + direction: [dx, dy, dz] ray direction (auto-normalized). + exclude_body: Body ID to exclude from intersection (-1 = none). + include_static: Whether to include static geoms. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + # validate vector shapes and reject zero-direction (mj_ray aborts the process on len=0) + try: + if len(origin) != 3: + return { + "status": "error", + "content": [{"text": f"raycast: 'origin' must be 3 elements [x,y,z], got {len(origin)}"}], + } + if len(direction) != 3: + return { + "status": "error", + "content": [{"text": f"raycast: 'direction' must be 3 elements [dx,dy,dz], got {len(direction)}"}], + } + except TypeError: + return { + "status": "error", + "content": [{"text": "raycast: 'origin' and 'direction' must be lists of 3 numbers"}], + } + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + pnt = np.array(origin, dtype=np.float64) + vec = np.array(direction, dtype=np.float64) + # Normalize direction + norm = np.linalg.norm(vec) + if norm < 1e-10: + return { + "status": "error", + "content": [{"text": "raycast: 'direction' vector is zero-length - supply a non-zero direction."}], + } + vec = vec / norm + + geomid = np.array([-1], dtype=np.int32) + dist = mj.mj_ray( + model, + data, + pnt, + vec, + None, # geom group filter (None = all) + 1 if include_static else 0, + exclude_body, + geomid, + ) + + hit = dist >= 0 + geom_name = None + if hit and geomid[0] >= 0: + geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, geomid[0]) + + result = { + "hit": hit, + "distance": float(dist) if hit else None, + "geom_id": int(geomid[0]) if hit else None, + "geom_name": geom_name, + "hit_point": (pnt + vec * dist).tolist() if hit else None, + } + + if hit: + text = f"🎯 Ray hit '{geom_name or geomid[0]}' at dist={dist:.4f}m, point={result['hit_point']}" + else: + text = "🎯 Ray: no intersection" + + return {"status": "success", "content": [{"text": text}, {"json": result}]} + + # Jacobians + + def get_jacobian( + self, + body_name: str | None = None, + site_name: str | None = None, + geom_name: str | None = None, + ) -> dict[str, Any]: + """Compute the Jacobian (position + rotation) for a body, site, or geom. + + The Jacobian maps joint velocities to Cartesian velocities: + v = J @ dq + + Returns both positional (3×nv) and rotational (3×nv) Jacobians. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + jacp = np.zeros((3, model.nv)) + jacr = np.zeros((3, model.nv)) + + if body_name: + obj_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_BODY, body_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + mj.mj_jacBody(model, data, jacp, jacr, obj_id) + label = f"body '{body_name}'" + elif site_name: + obj_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_SITE, site_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"Site '{site_name}' not found."}]} + mj.mj_jacSite(model, data, jacp, jacr, obj_id) + label = f"site '{site_name}'" + elif geom_name: + obj_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_GEOM, geom_name) + if obj_id < 0: + return {"status": "error", "content": [{"text": f"Geom '{geom_name}' not found."}]} + mj.mj_jacGeom(model, data, jacp, jacr, obj_id) + label = f"geom '{geom_name}'" + else: + return {"status": "error", "content": [{"text": "Specify body_name, site_name, or geom_name."}]} + + return { + "status": "success", + "content": [ + {"text": f"🧮 Jacobian for {label}: pos={jacp.shape}, rot={jacr.shape}, nv={model.nv}"}, + {"json": {"jacp": jacp.tolist(), "jacr": jacr.tolist(), "nv": model.nv}}, + ], + } + + # Energy + + def get_energy(self) -> dict[str, Any]: + """Compute potential and kinetic energy of the system.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_energyPos(model, data) + mj.mj_energyVel(model, data) + + potential = float(data.energy[0]) + kinetic = float(data.energy[1]) + total = potential + kinetic + + return { + "status": "success", + "content": [ + {"text": f"⚡ Energy: potential={potential:.4f}J, kinetic={kinetic:.4f}J, total={total:.4f}J"}, + {"json": {"potential": potential, "kinetic": kinetic, "total": total}}, + ], + } + + # Mass Matrix + + def get_mass_matrix(self) -> dict[str, Any]: + """Compute the full mass (inertia) matrix M(q). + + M is nv×nv where nv is the number of DoFs. + Useful for dynamics analysis, impedance control, etc. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + # data.qM is only valid after a forward pass; running mj_forward + # ensures the mass matrix reflects the current qpos (e.g. right after + # a reset/load_state). + mj.mj_forward(model, data) + nv = model.nv + M = np.zeros((nv, nv)) + if nv > 0: + mj.mj_fullM(model, M, data.qM) + rank = int(np.linalg.matrix_rank(M)) + cond = float(np.linalg.cond(M)) if rank > 0 else float("inf") + else: + # Empty scene (no DOFs yet) - return a well-typed zero payload + # instead of crashing in numpy on the empty matrix. + rank = 0 + cond = float("inf") + + return { + "status": "success", + "content": [ + {"text": f"🧮 Mass matrix: {nv}×{nv}, rank={rank}, cond={cond:.2e}"}, + { + "json": { + "shape": [nv, nv], + "rank": rank, + "condition_number": cond, + "diagonal": np.diag(M).tolist(), + "total_mass": float(np.sum(model.body_mass)), + } + }, + ], + } + + # Inverse Dynamics + + def inverse_dynamics(self) -> dict[str, Any]: + """Compute inverse dynamics: given qacc, what forces are needed? + + Runs mj_inverse to compute qfrc_inverse - the generalized forces + that would produce the current accelerations. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_inverse(model, data) + + # Build named force mapping + forces = {} + for i in range(model.njnt): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) + if name: + dof_adr = model.jnt_dofadr[i] + forces[name] = float(data.qfrc_inverse[dof_adr]) + + return { + "status": "success", + "content": [ + {"text": f"🔄 Inverse dynamics: {len(forces)} joint forces computed"}, + {"json": {"qfrc_inverse": forces}}, + ], + } + + # Body Introspection + + def get_body_state( + self, + body_name: str, + ) -> dict[str, Any]: + """Get the full state of a body: position, orientation, velocity, acceleration. + + Returns Cartesian pose + 6D spatial velocity (linear + angular). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + + # Position and orientation + pos = data.xpos[body_id].tolist() + quat = data.xquat[body_id].tolist() + rotmat = data.xmat[body_id].reshape(3, 3).tolist() + + # Velocity (6D: angular then linear in world frame) + vel = np.zeros(6) + mj.mj_objectVelocity(model, data, mj.mjtObj.mjOBJ_BODY, body_id, vel, 0) + linvel = vel[3:].tolist() + angvel = vel[:3].tolist() + + # Mass and inertia + mass = float(model.body_mass[body_id]) + com = data.xipos[body_id].tolist() + + state = { + "position": pos, + "quaternion": quat, + "rotation_matrix": rotmat, + "linear_velocity": linvel, + "angular_velocity": angvel, + "mass": mass, + "center_of_mass": com, + } + + text = ( + f"🏷️ Body '{body_name}' (id={body_id}):\n" + f" pos: [{pos[0]:.4f}, {pos[1]:.4f}, {pos[2]:.4f}]\n" + f" quat: [{quat[0]:.4f}, {quat[1]:.4f}, {quat[2]:.4f}, {quat[3]:.4f}]\n" + f" linvel: [{linvel[0]:.4f}, {linvel[1]:.4f}, {linvel[2]:.4f}]\n" + f" angvel: [{angvel[0]:.4f}, {angvel[1]:.4f}, {angvel[2]:.4f}]\n" + f" mass: {mass:.4f}kg, com: {com}" + ) + + return {"status": "success", "content": [{"text": text}, {"json": state}]} + + # Direct Joint Control + + def set_joint_positions( + self, + positions: dict[str, float] | list[float] | None = None, + robot_name: str | None = None, + ) -> dict[str, Any]: + """Set joint positions directly (bypassing actuators). + + Writes to qpos and runs mj_forward to update kinematics. + Useful for teleportation, IK solutions, or keyframe setting. + + Accepts EITHER form: + + * dict: {joint_name: value, ...} - explicit per-joint, safest in multi-robot scenes. + * list/tuple: [v0, v1, ...] - ordered positional. Must match a single robot's + joint count (when ``robot_name`` is given, that robot's joints; otherwise the + world must contain exactly one robot, or the call errors). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + # mutating qpos under a running policy races mj_step + if err := self._require_no_running_policy("set_joint_positions"): + return err + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if positions is None: + return { + "status": "error", + "content": [{"text": "set_joint_positions: 'positions' is required (list or dict of joint values)."}], + } + + # normalize list input to dict using a deterministic joint ordering + ignored: list[str] = [] + if isinstance(positions, (list, tuple)): + robots = list(self._world.robots.values()) + if robot_name is not None: + robots = [r for r in robots if r.name == robot_name] + if not robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + if len(robots) == 0: + return { + "status": "error", + "content": [ + { + "text": "set_joint_positions: list form requires a robot in the world; pass a dict instead, or add a robot first." + } + ], + } + if len(robots) > 1 and robot_name is None: + return { + "status": "error", + "content": [ + { + "text": f"set_joint_positions: list form is ambiguous with {len(robots)} robots; pass 'robot_name=' or use a dict." + } + ], + } + robot = robots[0] + joint_names = list(getattr(robot, "joint_names", []) or []) + if not joint_names: + # Fall back: enumerate joints that belong to this robot via namespace + ns = getattr(robot, "namespace", "") or "" + joint_names = [] + for jid in range(model.njnt): + jn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, jid) + if jn and (not ns or jn.startswith(ns)): + joint_names.append(jn) + if len(positions) != len(joint_names): + return { + "status": "error", + "content": [ + { + "text": ( + f"set_joint_positions: list length {len(positions)} does not match robot " + f"'{robot.name}' joint count {len(joint_names)}. Use a dict for partial updates." + ) + } + ], + } + positions = dict(zip(joint_names, positions, strict=True)) + elif not isinstance(positions, dict): + return { + "status": "error", + "content": [ + {"text": f"set_joint_positions: 'positions' must be a dict or list, got {type(positions).__name__}"} + ], + } + + set_count = 0 + with self._lock: + for jnt_name, value in positions.items(): + jnt_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_adr = model.jnt_qposadr[jnt_id] + data.qpos[qpos_adr] = float(value) + set_count += 1 + else: + ignored.append(jnt_name) + logger.warning("Joint '%s' not found, skipping", jnt_name) + + mj.mj_forward(model, data) + + msg = f"🎯 Set {set_count}/{len(positions)} joint positions, FK updated" + if ignored: + msg += f" (ignored: {ignored})" + return { + "status": "success", + "content": [{"text": msg}], + } + + def set_joint_velocities( + self, + velocities: dict[str, float] | list[float] | None = None, + robot_name: str | None = None, + ) -> dict[str, Any]: + """Set joint velocities directly. + + Writes to qvel. Useful for initializing dynamics. Accepts dict or list + (see set_joint_positions for list semantics). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("set_joint_velocities"): + return err + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if velocities is None: + return { + "status": "error", + "content": [{"text": "set_joint_velocities: 'velocities' is required (list or dict)."}], + } + + ignored: list[str] = [] + if isinstance(velocities, (list, tuple)): + robots = list(self._world.robots.values()) + if robot_name is not None: + robots = [r for r in robots if r.name == robot_name] + if not robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + if len(robots) == 0: + return { + "status": "error", + "content": [{"text": "set_joint_velocities: list form requires a robot in the world."}], + } + if len(robots) > 1 and robot_name is None: + return { + "status": "error", + "content": [ + { + "text": f"set_joint_velocities: list form is ambiguous with {len(robots)} robots; pass 'robot_name=' or use a dict." + } + ], + } + robot = robots[0] + joint_names = list(getattr(robot, "joint_names", []) or []) + if not joint_names: + ns = getattr(robot, "namespace", "") or "" + joint_names = [] + for jid in range(model.njnt): + jn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, jid) + if jn and (not ns or jn.startswith(ns)): + joint_names.append(jn) + if len(velocities) != len(joint_names): + return { + "status": "error", + "content": [ + { + "text": ( + f"set_joint_velocities: list length {len(velocities)} does not match robot " + f"'{robot.name}' joint count {len(joint_names)}. Use a dict for partial updates." + ) + } + ], + } + velocities = dict(zip(joint_names, velocities, strict=True)) + elif not isinstance(velocities, dict): + return { + "status": "error", + "content": [ + { + "text": f"set_joint_velocities: 'velocities' must be a dict or list, got {type(velocities).__name__}" + } + ], + } + + set_count = 0 + with self._lock: + for jnt_name, value in velocities.items(): + jnt_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + dof_adr = model.jnt_dofadr[jnt_id] + data.qvel[dof_adr] = float(value) + set_count += 1 + else: + ignored.append(jnt_name) + + msg = f"💨 Set {set_count}/{len(velocities)} joint velocities" + if ignored: + msg += f" (ignored: {ignored})" + return { + "status": "success", + "content": [{"text": msg}], + } + + # Sensor Readout + + def get_sensor_data(self, sensor_name: str | None = None) -> dict[str, Any]: + """Read sensor values from the simulation. + + MuJoCo supports: jointpos, jointvel, accelerometer, gyro, force, + torque, touch, rangefinder, framequat, subtreecom, clock, etc. + + Args: + sensor_name: Specific sensor name, or None for all sensors. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if model.nsensor == 0: + # distinguish "no sensors at all" from "that specific sensor not found" + if sensor_name: + return { + "status": "error", + "content": [{"text": f"Sensor '{sensor_name}' not found. Model has no sensors."}], + } + return {"status": "success", "content": [{"text": "📡 No sensors in model."}]} + + mj.mj_forward(model, data) + + sensors = {} + for i in range(model.nsensor): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_SENSOR, i) + if not name: + name = f"sensor_{i}" + + adr = model.sensor_adr[i] + dim = model.sensor_dim[i] + values = data.sensordata[adr : adr + dim].tolist() + + if sensor_name and name != sensor_name: + continue + + sensors[name] = { + "values": values if dim > 1 else values[0], + "dim": int(dim), + "type": int(model.sensor_type[i]), + } + + if sensor_name and sensor_name not in sensors: + return {"status": "error", "content": [{"text": f"Sensor '{sensor_name}' not found."}]} + + lines = [f"📡 Sensors ({len(sensors)}/{model.nsensor}):"] + for name, info in sensors.items(): + lines.append(f"{name}: {info['values']} (dim={info['dim']})") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"json": {"sensors": sensors}}], + } + + # Runtime Model Modification + + def set_body_properties( + self, + body_name: str, + mass: float | None = None, + ) -> dict[str, Any]: + """Modify body properties at runtime (no recompile needed). + + Changes take effect on the next mj_step. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("set_body_properties"): + return err + + # mass must be > 0 (physics invariant) + if mass is not None: + try: + mass = float(mass) + except (TypeError, ValueError): + return { + "status": "error", + "content": [{"text": f"set_body_properties: 'mass' must be a positive number, got {mass!r}"}], + } + if mass <= 0: + return { + "status": "error", + "content": [{"text": f"set_body_properties: 'mass' must be > 0, got {mass}"}], + } + + mj = _ensure_mujoco() + model = self._world._model + body_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + + changes = [] + with self._lock: + if mass is not None: + old_mass = float(model.body_mass[body_id]) + model.body_mass[body_id] = mass + changes.append(f"mass: {old_mass:.3f} → {mass:.3f}") + + return { + "status": "success", + "content": [{"text": f"🔧 Body '{body_name}': {', '.join(changes)}"}], + } + + def set_geom_properties( + self, + geom_name: str | None = None, + geom_id: int | None = None, + color: list[float] | None = None, + friction: list[float] | None = None, + size: list[float] | None = None, + ) -> dict[str, Any]: + """Modify geom properties at runtime (no recompile needed). + + Changes take effect immediately for rendering (color) or next step (friction, size). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("set_geom_properties"): + return err + + mj = _ensure_mujoco() + model = self._world._model + + gid = geom_id + if geom_name: + gid = self._resolve_mj_name(mj.mjtObj.mjOBJ_GEOM, geom_name) + # our add_object pipeline names geoms as ``{object_name}_geom``. + # Accept the plain object name as a convenience alias. + if (gid is None or gid < 0) and not geom_name.endswith("_geom"): + gid = self._resolve_mj_name(mj.mjtObj.mjOBJ_GEOM, f"{geom_name}_geom") + if gid is None or gid < 0 or gid >= model.ngeom: + return {"status": "error", "content": [{"text": f"Geom '{geom_name or geom_id}' not found."}]} + + label = geom_name or f"geom_{gid}" + changes = [] + + with self._lock: + if color is not None: + model.geom_rgba[gid] = color[:4] if len(color) >= 4 else color[:3] + [1.0] + changes.append(f"color → {model.geom_rgba[gid].tolist()}") + + if friction is not None: + fric = friction[:3] if len(friction) >= 3 else friction + [0.0] * (3 - len(friction)) + model.geom_friction[gid] = fric + changes.append(f"friction → {fric}") + + if size is not None: + n = min(len(size), 3) + model.geom_size[gid, :n] = size[:n] + changes.append(f"size → {model.geom_size[gid].tolist()}") + + return { + "status": "success", + "content": [{"text": f"🔧 Geom '{label}': {', '.join(changes)}"}], + } + + # Contact Force Analysis + + def get_contact_forces(self) -> dict[str, Any]: + """Get detailed contact forces for all active contacts. + + Uses mj_contactForce for each active contact pair. + Returns normal and friction forces. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom1) or f"geom_{c.geom1}" + g2 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom2) or f"geom_{c.geom2}" + + # Get contact force (normal + friction in contact frame) + force = np.zeros(6) + mj.mj_contactForce(model, data, i, force) + + contacts.append( + { + "geom1": g1, + "geom2": g2, + "distance": float(c.dist), + "position": c.pos.tolist(), + "normal_force": float(force[0]), + "friction_force": force[1:3].tolist(), + "full_wrench": force.tolist(), + } + ) + + if not contacts: + return {"status": "success", "content": [{"text": "💥 No active contacts."}]} + + lines = [f"💥 {len(contacts)} contacts:"] + for c in contacts[:15]: + lines.append(f"{c['geom1']} ↔ {c['geom2']}: normal={c['normal_force']:.3f}N, dist={c['distance']:.4f}m") + if len(contacts) > 15: + lines.append(f" ... and {len(contacts) - 15} more") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"json": {"contacts": contacts}}], + } + + # Multi-Ray (batch raycasting) + + def multi_raycast( + self, + origin: list[float], + directions: list[list[float]], + exclude_body: int = -1, + ) -> dict[str, Any]: + """Cast multiple rays from a single origin (e.g., for LIDAR simulation). + + Efficiently casts N rays using individual mj_ray calls. + Returns array of distances and hit geoms. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + # validate origin shape; per-ray zero-direction guard (avoid mj_ray abort) + try: + if len(origin) != 3: + return { + "status": "error", + "content": [{"text": f"multi_raycast: 'origin' must be 3 elements [x,y,z], got {len(origin)}"}], + } + except TypeError: + return {"status": "error", "content": [{"text": "multi_raycast: 'origin' must be a list of 3 numbers"}]} + + pnt = np.array(origin, dtype=np.float64) + results: list[dict[str, Any]] = [] + + for idx, d in enumerate(directions): + try: + if len(d) != 3: + results.append( + { + "distance": None, + "geom_id": None, + "error": f"ray[{idx}]: direction must have 3 elements, got {len(d)}", + } + ) + continue + except TypeError: + results.append( + {"distance": None, "geom_id": None, "error": f"ray[{idx}]: direction must be a list of 3 numbers"} + ) + continue + vec = np.array(d, dtype=np.float64) + norm = np.linalg.norm(vec) + if norm < 1e-10: + results.append({"distance": None, "geom_id": None, "error": f"ray[{idx}]: zero-length direction"}) + continue + vec /= norm + geomid = np.array([-1], dtype=np.int32) + dist = mj.mj_ray(model, data, pnt, vec, None, 1, exclude_body, geomid) + results.append( + { + "distance": float(dist) if dist >= 0 else None, + "geom_id": int(geomid[0]) if dist >= 0 else None, + } + ) + + hit_count = sum(1 for r in results if r["distance"] is not None) + return { + "status": "success", + "content": [ + {"text": f"🎯 Multi-ray: {hit_count}/{len(directions)} hits from {origin}"}, + {"json": {"rays": results}}, + ], + } + + # Forward Kinematics (explicit) + + def forward_kinematics(self, body_name: str | None = None) -> dict[str, Any]: + """Run forward kinematics to update all body positions/orientations. + + Usually called implicitly by mj_step, but useful after manually + setting qpos to see updated Cartesian positions. + + If ``body_name`` is given, the response is filtered to that + single body (and errors cleanly if the body doesn't exist). + Otherwise returns every body as before. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_kinematics(model, data) + mj.mj_comPos(model, data) + mj.mj_camlight(model, data) + + if body_name is not None: + bid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if bid < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + body_payload = { + "position": data.xpos[bid].tolist(), + "quaternion": data.xquat[bid].tolist(), + } + return { + "status": "success", + "content": [ + {"text": f"🦴 FK for '{body_name}': pos={body_payload['position']}"}, + {"json": {"body": body_name, **body_payload}}, + ], + } + + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + bodies[name] = { + "position": data.xpos[i].tolist(), + "quaternion": data.xquat[i].tolist(), + } + + return { + "status": "success", + "content": [ + {"text": f"🦴 FK computed for {model.nbody} bodies"}, + {"json": {"bodies": bodies}}, + ], + } + + # Total Mass + + def get_total_mass(self) -> dict[str, Any]: + """Get total mass and per-body mass breakdown.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model = self._world._model + + total = float(mj.mj_getTotalmass(model)) + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + m = float(model.body_mass[i]) + if m > 0: + bodies[name] = m + + return { + "status": "success", + "content": [ + {"text": f"⚖️ Total mass: {total:.4f}kg ({len(bodies)} bodies with mass)"}, + {"json": {"total_mass": total, "bodies": bodies}}, + ], + } + + # Export Model XML + + def export_xml(self, output_path: str | None = None) -> dict[str, Any]: + """Export the current model to MJCF XML. + + Uses mj_saveLastXML - exports the exact model currently loaded, + including any runtime modifications. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + + if output_path: + mj.mj_saveLastXML(output_path, self._world._model) + return {"status": "success", "content": [{"text": f"📄 Model exported to {output_path}"}]} + else: + # Return XML string via saveLastXML to temp file + import os + import tempfile + + with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as tmp: + tmpfile = tmp.name + mj.mj_saveLastXML(tmpfile, self._world._model) + try: + with open(tmpfile) as f: + xml = f.read() + finally: + os.unlink(tmpfile) + return { + "status": "success", + "content": [ + {"text": f"📄 Model XML ({len(xml)} chars):\n{xml[:2000]}{'...' if len(xml) > 2000 else ''}"} + ], + } diff --git a/strands_robots/simulation/mujoco/randomization.py b/strands_robots/simulation/mujoco/randomization.py new file mode 100644 index 0000000..923e071 --- /dev/null +++ b/strands_robots/simulation/mujoco/randomization.py @@ -0,0 +1,137 @@ +"""Domain randomization mixin.""" + +import logging +from typing import TYPE_CHECKING, Any + +import numpy as np + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RandomizationMixin: + """Domain randomization mixed into ``Simulation``. + + Recolors geoms, perturbs lighting, and scales body mass / geom friction + by a random factor inside a user-supplied range. + + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world``, ``self._lock``, and the host's + ``_require_no_running_policy`` / ``_require_world`` helpers. ``TYPE_CHECKING`` + stubs below exist so mypy accepts those lookups; they are a + documentary contract, not an enforceable protocol. + """ + + if TYPE_CHECKING: + import threading + + from strands_robots.simulation.models import SimWorld + + _lock: "threading.Lock" + _world: "SimWorld | None" + + def _require_no_running_policy( + self, action_name: str, robot_name: str | None = None + ) -> dict[str, Any] | None: ... + def _require_world(self) -> dict[str, Any] | None: ... + + def randomize( + self, + randomize_colors: bool = True, + randomize_lighting: bool = True, + randomize_physics: bool = False, + randomize_positions: bool = False, + position_noise: float = 0.02, + color_range: tuple[float, float] = (0.1, 1.0), + friction_range: tuple[float, float] = (0.5, 1.5), + mass_range: tuple[float, float] = (0.5, 2.0), + seed: int | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Apply domain randomization to the scene. + + Each flag is opt-in per-axis. Defaults: + - ``randomize_colors=True`` - geom RGB re-sampled in ``color_range``. + - ``randomize_lighting=True`` - light pos jittered ±0.5m, diffuse resampled. + - ``randomize_physics=False`` - friction/mass left untouched unless asked. + - ``randomize_positions=False`` - object qpos left untouched unless asked. + + "No flags" means "nothing is randomized" - the call is a no-op. This + matches the LLM ergonomics principle: explicit is better than implicit. + Randomization IS destructive (writes to ``model.geom_*`` / ``body_*`` + arrays and to ``data.qpos``); recompile the scene to undo. + + Args: + randomize_colors: Re-sample geom RGB values. + randomize_lighting: Jitter light positions + diffuse colour. + randomize_physics: Scale geom friction and body mass. + randomize_positions: Add uniform noise to dynamic-object xyz. + position_noise: Max ± xyz offset in meters when randomising positions. + color_range: (lo, hi) for uniform RGB sampling. + friction_range: (lo, hi) multiplicative scale on friction[0]. + mass_range: (lo, hi) multiplicative scale on body_mass. + seed: Optional np.random seed for reproducibility. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + # domain randomization mutates model arrays; a running policy racing with it is UB + if err := self._require_no_running_policy("randomize"): + return err + + rng = np.random.default_rng(seed) + mj = _ensure_mujoco() + model = self._world._model + data = self._world._data + changes = [] + + with self._lock: + if randomize_colors: + for i in range(model.ngeom): + geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, i) + if geom_name and geom_name != "ground": + model.geom_rgba[i, :3] = rng.uniform(color_range[0], color_range[1], size=3) + changes.append(f"🎨 Colors: {model.ngeom} geoms randomized") + + if randomize_lighting: + for i in range(model.nlight): + model.light_pos[i] += rng.uniform(-0.5, 0.5, size=3) + model.light_diffuse[i] = rng.uniform(0.3, 1.0, size=3) + changes.append(f"💡 Lighting: {model.nlight} lights randomized") + + if randomize_physics: + friction_scales = {} + for i in range(model.ngeom): + gn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, i) or f"geom_{i}" + f = float(rng.uniform(*friction_range)) + model.geom_friction[i, 0] *= f + friction_scales[gn] = f + mass_scales = {} + for i in range(model.nbody): + if model.body_mass[i] > 0: + bn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + s = float(rng.uniform(*mass_range)) + model.body_mass[i] *= s + mass_scales[bn] = s + changes.append( + f"⚙️ Physics: {len(friction_scales)} geoms friction-scaled, {len(mass_scales)} bodies mass-scaled" + ) + changes.append(f" friction_scales={friction_scales}") + changes.append(f" mass_scales={mass_scales}") + + if randomize_positions: + for obj_name, obj in self._world.objects.items(): + if not obj.is_static: + jnt_name = f"{obj_name}_joint" + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_addr = model.jnt_qposadr[jnt_id] + noise = rng.uniform(-position_noise, position_noise, size=3) + data.qpos[qpos_addr : qpos_addr + 3] += noise + mj.mj_forward(model, data) + changes.append(f"📍 Positions: ±{position_noise}m noise on dynamic objects") + + return { + "status": "success", + "content": [{"text": "🎲 Domain Randomization applied:\n" + "\n".join(changes)}], + } diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py new file mode 100644 index 0000000..0ba6c01 --- /dev/null +++ b/strands_robots/simulation/mujoco/recording.py @@ -0,0 +1,218 @@ +"""Recording mixin - start/stop trajectory recording to LeRobotDataset.""" + +import logging +import shutil +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RecordingMixin: + """Trajectory recording mixed into ``Simulation``. + + Writes per-step observations + actions + instruction to a LeRobotDataset + via ``start_recording`` / ``stop_recording`` and the ``on_frame`` hook + in ``PolicyRunner``. Separately from that, ``start_cameras_recording`` + dumps raw per-camera MP4s. + + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world`` (trajectory buffer + dataset_recorder live in + ``_world._backend_state``). ``TYPE_CHECKING`` stub below exists so mypy + accepts the ``_world`` lookup; it is a documentary contract, not an + enforceable protocol. + """ + + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + + def start_recording( + self, + repo_id: str = "local/sim_recording", + task: str = "", + fps: int = 30, + root: str | None = None, + push_to_hub: bool = False, + vcodec: str = "libsvtav1", + overwrite: bool = False, + ) -> dict[str, Any]: + """Start recording to LeRobotDataset format (parquet + per-camera MP4). + + Requires the ``lerobot`` extra for the dataset schema. If you only + need plain MP4 video (no dataset schema, no policy-training metadata), + use :meth:`start_cameras_recording` - it runs under the + ``[sim-mujoco]`` extra alone (imageio-ffmpeg backend). + + Raises: + Friendly error when ``lerobot`` is not installed, directing the + caller to :meth:`start_cameras_recording` or to install the + optional extra. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + _DatasetRecorder: Any = None + _has_lerobot = False + try: + from strands_robots.dataset_recorder import DatasetRecorder as _DatasetRecorder + from strands_robots.dataset_recorder import has_lerobot_dataset as _check_lerobot + + _has_lerobot = _check_lerobot() + except ImportError: + pass + + if not _has_lerobot or _DatasetRecorder is None: + return { + "status": "error", + "content": [ + { + "text": ( + "start_recording produces a LeRobotDataset (parquet + video) and " + "requires the lerobot extra. For plain MP4 video under the " + "[sim-mujoco] extra alone, use start_cameras_recording instead.\n" + "\n" + " - Dataset + policy training data: pip install 'strands-robots[lerobot]'\n" + " - Plain MP4 only: start_cameras_recording(cameras=..., output_dir=...)" + ) + } + ], + } + + self._world._backend_state["recording"] = True + self._world._backend_state["trajectory"] = [] + self._world._backend_state["push_to_hub"] = push_to_hub + + try: + if overwrite: + if root: + dataset_dir = Path(root) + elif "/" not in repo_id or repo_id.startswith("/") or repo_id.startswith("./"): + dataset_dir = Path(repo_id) + else: + dataset_dir = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id + if dataset_dir.exists() and dataset_dir.is_dir(): + shutil.rmtree(dataset_dir) + logger.info("Removed existing dataset dir: %s", dataset_dir) + + # Collect joint names from every robot. When the scene contains + # more than one robot (e.g. multi-agent dual-task recording), prefix + # each joint with the robot's instance name (``alice__shoulder_pan``) + # so the dataset schema has unique joint ids per agent. Single-robot + # scenes keep the clean ``shoulder_pan`` names for backwards compat. + joint_names: list[str] = [] + camera_keys: list[str] = [] + robot_type = "unknown" + multi_robot = len(self._world.robots) > 1 + for rname, robot in self._world.robots.items(): + if multi_robot: + joint_names.extend(f"{rname}__{jn}" for jn in robot.joint_names) + else: + joint_names.extend(robot.joint_names) + robot_type = robot.data_config or rname + + mj = _ensure_mujoco() + for i in range(self._world._model.ncam): + cam_name = mj.mj_id2name(self._world._model, mj.mjtObj.mjOBJ_CAMERA, i) + if not cam_name: + continue + # LeRobot feature names can't contain '/' (reserved for + # nested-feature addressing). When a robot injects a + # namespaced camera (e.g. ``arm0/wrist_cam``), collapse + # the separator to ``__`` for the dataset schema. + safe_name = cam_name.replace("/", "__") + camera_keys.append(safe_name) + + assert _DatasetRecorder is not None # checked above + self._world._backend_state["dataset_recorder"] = _DatasetRecorder.create( + repo_id=repo_id, + fps=fps, + robot_type=robot_type, + joint_names=joint_names, + camera_keys=camera_keys, + task=task, + root=root, + vcodec=vcodec, + ) + return { + "status": "success", + "content": [ + { + "text": ( + f"Recording to LeRobotDataset: {repo_id}\n" + f"{len(joint_names)} joints, {len(camera_keys)} cameras @ {fps}fps\n" + f"Codec: {vcodec} | Task: {task or '(set per policy)'}\n" + f"Run policies to capture frames, then stop_recording to save episode" + ) + } + ], + } + except Exception as e: + self._world._backend_state["recording"] = False + logger.error("Dataset recorder init failed: %s", e) + return {"status": "error", "content": [{"text": f"Dataset init failed: {e}"}]} + + def stop_recording(self, output_path: str | None = None) -> dict[str, Any]: + """Stop recording and save episode to LeRobotDataset. + + idempotent - calling when not recording succeeds with a + 'Was not recording' message so callers can safely call it unconditionally. + """ + if self._world is None or not self._world._backend_state.get("recording", False): + return {"status": "success", "content": [{"text": "Was not recording."}]} + + self._world._backend_state["recording"] = False + recorder = self._world._backend_state.get("dataset_recorder", None) + + if recorder is None: + return {"status": "error", "content": [{"text": "No dataset recorder active."}]} + + recorder.save_episode() + push_result = None + if self._world._backend_state.get("push_to_hub", False): + push_result = recorder.push_to_hub(tags=["strands-robots", "sim"]) + + repo_id = recorder.repo_id + frame_count = recorder.frame_count + episode_count = recorder.episode_count + root = recorder.root + + recorder.finalize() + self._world._backend_state["dataset_recorder"] = None + self._world._backend_state["trajectory"] = [] + + text = ( + f"Episode saved to LeRobotDataset\n" + f"{repo_id} -- {frame_count} frames, {episode_count} episode(s)\n" + f"Local: {root}" + ) + if push_result and push_result.get("status") == "success": + text += "\nPushed to HuggingFace Hub" + + return {"status": "success", "content": [{"text": text}]} + + def get_recording_status(self) -> dict[str, Any]: + """Returns success in every lifecycle state (no world / not + recording / recording) with a distinguishing message so callers can + poll it unconditionally without try/except.""" + if self._world is None: + return { + "status": "success", + "content": [{"text": "⚪ No world - call create_world to start recording."}], + } + + recording = self._world._backend_state.get("recording", False) + steps = len(self._world._backend_state.get("trajectory", [])) + + if recording: + text = f"🔴 Recording: {steps} steps captured" + else: + text = f"⚪ Not recording (last episode: {steps} steps)" + + return { + "status": "success", + "content": [{"text": text}], + } diff --git a/strands_robots/simulation/mujoco/rendering.py b/strands_robots/simulation/mujoco/rendering.py new file mode 100644 index 0000000..4137275 --- /dev/null +++ b/strands_robots/simulation/mujoco/rendering.py @@ -0,0 +1,741 @@ +"""Rendering mixin - render, render_depth, get_contacts, observation helpers.""" + +import io +import logging +from typing import TYPE_CHECKING, Any + +from strands_robots.simulation.mujoco.backend import _can_render, _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RenderingMixin: + """Rendering + observation helpers mixed into ``Simulation``. + + Owns ``render``, ``render_depth``, ``render_all``, ``get_contacts``, and + the low-level ``_apply_sim_action`` (MuJoCo ``ctrl[]`` write + mj_step). + + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world``, ``self._renderer_tls``, ``self._renderer_model``, + ``self.default_width`` / ``self.default_height``, ``self._lock`` and + ``self._viewer_handle``. ``TYPE_CHECKING`` stubs below exist so mypy + accepts those lookups; they are a documentary contract, not an + enforceable protocol. + + Thread-safety note: MuJoCo ``Renderer`` uses thread-local GL contexts + (CGL on macOS, GLX on Linux). A renderer created on thread A cannot be + reused from thread B - we keep one per-thread via ``_renderer_tls``. + """ + + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + + _renderer_model: Any + _renderer_tls: Any # threading.local() - per-thread renderer dict + default_width: int + default_height: int + + def _validate_render_dims(self, width: int, height: int) -> dict[str, Any] | None: + """reject non-positive render dims; convert MuJoCo's framebuffer + overflow to a plain-English message that tells the LLM the actual cap. + """ + if not isinstance(width, int) or not isinstance(height, int): + return { + "status": "error", + "content": [ + {"text": f"render: width/height must be int, got {type(width).__name__}/{type(height).__name__}."} + ], + } + if width <= 0 or height <= 0: + return { + "status": "error", + "content": [{"text": f"render: width and height must be > 0, got {width}x{height}."}], + } + if self._world is not None and self._world._model is not None: + max_w = int(getattr(self._world._model.vis.global_, "offwidth", 1280)) + max_h = int(getattr(self._world._model.vis.global_, "offheight", 960)) + if width > max_w or height > max_h: + return { + "status": "error", + "content": [ + { + "text": ( + f"render: requested {width}x{height} exceeds the offscreen " + f"framebuffer cap ({max_w}x{max_h}). Lower width/height or " + f"rebuild the model with a larger ." + ) + } + ], + } + return None + + def _get_renderer(self, width: int, height: int): + """Get a cached MuJoCo renderer, creating one only if needed. + + Returns None if rendering is unavailable (headless without EGL/OSMesa). + Callers must handle None return. + + Thread-safety: renderers are cached per-thread via ``threading.local`` + because ``mujoco.Renderer`` binds a GL context to the thread that + creates it (CGL on macOS, GLX on Linux). Sharing renderers across + threads would cause ``cgl.free()`` segfaults at cleanup time. + """ + if not _can_render(): + return None + mj = _ensure_mujoco() + assert self._world is not None # callers must check + + # Get or create per-thread renderer dict + renderers = getattr(self._renderer_tls, "renderers", None) + if renderers is None: + renderers = {} + self._renderer_tls.renderers = renderers + self._renderer_tls.model = None + + # Invalidate this thread's cache if model changed (e.g. after recompile) + if self._renderer_tls.model is not self._world._model: + renderers.clear() + self._renderer_tls.model = self._world._model + # Keep the per-instance marker for compatibility with any remaining + # read paths that checked self._renderer_model. + self._renderer_model = self._world._model + + key = (width, height) + if key not in renderers: + renderers[key] = mj.Renderer(self._world._model, height=height, width=width) + return renderers[key] + + def _get_sim_observation(self, robot_name: str, *, skip_images: bool = False) -> dict[str, Any]: + """Get observation from sim: joint state + cameras (unless skipped). + + Implements :meth:`SimEngine.get_observation`'s schema. + + Multi-robot note: when the injected robot XML was namespaced + (e.g. ``arm0/shoulder_pan`` in MuJoCo to allow multiple same-config + robots), we look up the prefixed MuJoCo name but return the short + name in the observation dict so the policy sees a stable, config-level + schema regardless of how many robots are in the scene. + """ + mj = _ensure_mujoco() + assert self._world is not None # callers must check + model, data = self._world._model, self._world._data + robot = self._world.robots[robot_name] + pfx = robot.namespace or "" + + obs = {} + for jnt_name in robot.joint_names: + # Try namespaced name first (multi-robot), fall back to raw. + lookup = pfx + jnt_name if pfx else jnt_name + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, lookup) + if jnt_id < 0 and pfx: + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + obs[jnt_name] = float(data.qpos[model.jnt_qposadr[jnt_id]]) + + if skip_images: + return obs + + # Render every camera defined on the model plus any python-side cameras. + # Individual camera failures are logged but do not drop joint state. + cameras_to_render = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + for pycam_name in self._world.cameras: + if pycam_name not in cameras_to_render: + cameras_to_render.append(pycam_name) + + for cname in cameras_to_render: + if not cname: + continue + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, cname) + cam_info = self._world.cameras.get(cname) + h = cam_info.height if cam_info else self.default_height + w = cam_info.width if cam_info else self.default_width + try: + renderer = self._get_renderer(w, h) + if renderer is None: + continue + if cam_id >= 0: + renderer.update_scene(data, camera=cam_id) + else: + renderer.update_scene(data) + obs[cname] = renderer.render().copy() + except (RuntimeError, ValueError) as e: + # Individual camera failure shouldn't stop joint state collection. + # Common cause: camera ID invalid after scene recompile. + logger.debug("Camera render failed for %s: %s", cname, e) + + return obs + + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: + """Apply action dict to sim (same interface as robot.send_action). + + Multi-robot note: action keys are *short* names (e.g. ``shoulder_pan``). + We look up the namespaced MuJoCo actuator/joint name for this + specific ``robot_name`` so the same action dict routes to the right + physical actuator when multiple same-config robots exist. + """ + mj = _ensure_mujoco() + assert self._world is not None # callers must check + model, data = self._world._model, self._world._data + robot = self._world.robots.get(robot_name) + pfx = robot.namespace if robot else "" + + def _lookup(obj_type: Any, name: str) -> int: + """Try namespaced lookup first, fall back to raw.""" + if pfx: + i = mj.mj_name2id(model, obj_type, pfx + name) + if i >= 0: + return i + return int(mj.mj_name2id(model, obj_type, name)) + + for key, value in action_dict.items(): + act_id = _lookup(mj.mjtObj.mjOBJ_ACTUATOR, key) + if act_id >= 0: + data.ctrl[act_id] = float(value) + else: + # Fallback: key is a joint name - find the actuator that + # drives this joint via actuator_trnid (joint ID → actuator). + jnt_id = _lookup(mj.mjtObj.mjOBJ_JOINT, key) + if jnt_id >= 0: + for ai in range(model.nu): + if model.actuator_trnid[ai, 0] == jnt_id: + data.ctrl[ai] = float(value) + break + + for _ in range(max(1, n_substeps)): + mj.mj_step(model, data) + + assert self._world is not None + self._world.sim_time = data.time + assert self._world is not None # callers must check + self._world.step_count += n_substeps + + if hasattr(self, "_viewer_handle") and self._viewer_handle is not None: + self._viewer_handle.sync() + + def render( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + """Render a camera view as base64 PNG image.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + # treat `None` as "use default", but `0` / negative values must + # still hit the validator (bool coercion would swallow them silently). + w = self.default_width if width is None else width + h = self.default_height if height is None else height + if err := self._validate_render_dims(w, h): + return err + + try: + renderer = self._get_renderer(w, h) + if renderer is None: + return { + "status": "error", + "content": [ + { + "text": ( + " Rendering unavailable (no OpenGL context). " + "Install EGL or OSMesa for offscreen rendering: " + "apt-get install libosmesa6-dev" + ) + } + ], + } + # strict camera validation - no silent fallback to default. + # Special 'default' / 'free' tokens route to the free camera; any + # other name MUST resolve or we error (prevents the LLM from + # believing it rendered viewpoint X while actually getting free-cam). + if camera_name in (None, "", "default", "free"): + cam_id = -1 + label = "free (default)" + else: + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + if cam_id < 0: + return { + "status": "error", + "content": [ + {"text": f"Camera '{camera_name}' not found. Available: {self._list_camera_names()}"} + ], + } + label = camera_name + + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + + img = renderer.render().copy() + + from PIL import Image + + pil_img = Image.fromarray(img) + buffer = io.BytesIO() + pil_img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + # summary stats so render_all can flag empty-looking frames + # without decoding the PNG a second time. + import numpy as _np + + pixel_var = float(_np.var(img)) + pixel_mean = float(_np.mean(img)) + + return { + "status": "success", + "content": [ + {"text": f"📸 {w}x{h} from '{label}' at t={self._world.sim_time:.3f}s"}, + {"image": {"format": "png", "source": {"bytes": png_bytes}}}, + {"json": {"pixel_variance": pixel_var, "pixel_mean": pixel_mean, "camera": label}}, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"Render failed: {e}"}]} + + def render_depth( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + """Render depth map from a camera.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + # see note in render() re: None vs 0/negative. + w = self.default_width if width is None else width + h = self.default_height if height is None else height + if err := self._validate_render_dims(w, h): + return err + + try: + # strict camera validation (same policy as render()) + if camera_name in (None, "", "default", "free"): + cam_id = -1 + label = "free (default)" + else: + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + if cam_id < 0: + return { + "status": "error", + "content": [ + {"text": f"Camera '{camera_name}' not found. Available: {self._list_camera_names()}"} + ], + } + label = camera_name + + renderer = self._get_renderer(w, h) + if renderer is None: + return { + "status": "error", + "content": [ + { + "text": ( + " Depth rendering unavailable (no OpenGL context). " + "Install EGL or OSMesa for offscreen rendering." + ) + } + ], + } + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + # MuJoCo prints a one-time ARB_clip_control warning on macOS + # when depth precision is reduced. Capture stderr on the first + # depth render so we can surface the warning in the response + # text (the LLM otherwise never hears about it). + clip_warn = getattr(self, "_depth_warn_text", None) + if clip_warn is None: + import contextlib as _ctx + import io as _io + import sys as _sys + + buf = _io.StringIO() + with _ctx.redirect_stderr(buf): + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + captured = buf.getvalue() + # Also forward to the real stderr so logs don't vanish. + if captured and _sys.__stderr__ is not None: + try: + _sys.__stderr__.write(captured) + except Exception: + pass + if "ARB_clip_control" in captured: + self._depth_warn_text = "⚠️ Depth accuracy limited on this GPU (missing ARB_clip_control)" + else: + self._depth_warn_text = "" + clip_warn = self._depth_warn_text + else: + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + + text = f"📸 Depth {w}x{h} from '{label}'\nMin: {float(depth.min()):.3f}m, Max: {float(depth.max()):.3f}m" + if clip_warn: + text += f"\n{clip_warn}" + return { + "status": "success", + "content": [ + {"text": text}, + {"json": {"depth_min": float(depth.min()), "depth_max": float(depth.max())}}, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"Depth render failed: {e}"}]} + + def _list_camera_names(self) -> list[str]: + """helper to list all camera names (model-defined + SimCamera aliases) + for error messages when an unknown camera_name is requested.""" + import mujoco as _mj + + names: list[str] = [] + if self._world is not None and self._world._model is not None: + for cid in range(self._world._model.ncam): + raw = _mj.mj_id2name(self._world._model, _mj.mjtObj.mjOBJ_CAMERA, cid) + if raw: + names.append(raw) + # Include SimCamera registry keys (may match model names; dedupe) + for k in self._world.cameras.keys() if self._world else (): + if k not in names: + names.append(k) + return names + + def get_contacts(self) -> dict[str, Any]: + """Return the list of active geom-geom contacts at the current step. + + We run ``mj_forward`` first so the contact list reflects the + current qpos/qvel even immediately after ``reset`` or ``add_robot`` + (without this, stale contacts from the previous step / uninitialised + memory can appear as phantom penetrations at t=0). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + # refresh contact list without advancing time. + mj.mj_forward(model, data) + + def _resolve_geom(gid: int) -> str: + """Prefer the geom name; fall back to its parent body name; then id.""" + gn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, gid) + if gn: + return gn + # Walk to the parent body name. + try: + bid = int(model.geom_bodyid[gid]) + bn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, bid) + if bn: + return f"{bn}/geom_{gid}" + except (IndexError, AttributeError): + pass + return f"geom_{gid}" + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = _resolve_geom(c.geom1) + g2 = _resolve_geom(c.geom2) + contacts.append({"geom1": g1, "geom2": g2, "dist": float(c.dist), "pos": c.pos.tolist()}) + + text = f"💥 {len(contacts)} contacts" if contacts else "No contacts." + if contacts: + for c in contacts[:10]: + text += f"\n • {c['geom1']} ↔ {c['geom2']} (d={c['dist']:.4f})" + + return { + "status": "success", + "content": [{"text": text}, {"json": {"contacts": contacts}}], + } + + # Multi-camera capture - Session recording for simulation + + # + # Design: + # - render_all(cameras=None, width=, height=) - single-shot snapshot + # of every camera at current sim_time. One PNG per camera. + # - start_cameras_recording(...) - daemon thread, one imageio writer + # per camera, appends frames at fps. + # - stop_cameras_recording() - flushes writers, returns paths + sizes. + # - get_cameras_recording_status() - frame counts, elapsed, per-cam. + # + # Thread safety: _get_renderer is thread-local (threading.local), so the + # background thread creates its own GL context. No shared state with + # main dispatch thread. + + def _active_camera_list(self, cameras): + """Resolve cameras=None to every camera currently in the world.""" + if self._world is None or self._world._model is None: + return [] + mj = _ensure_mujoco() + model = self._world._model + from_model = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + from_model = [c for c in from_model if c] + py_side = list(self._world.cameras.keys()) if self._world else [] + all_cams = list(dict.fromkeys(from_model + py_side)) + if cameras is None: + return all_cams + missing = [c for c in cameras if c not in all_cams] + if missing: + logger.warning("Unknown camera(s) requested for capture: %s", missing) + return [c for c in cameras if c in all_cams] + + def render_all(self, cameras=None, width=None, height=None): + """Render every (or a subset of) camera in one call. + + Counterpart to ``render()`` for multi-view workflows - e.g. stereo, + overhead + wrist, or all cameras in a 4-view grid. Each camera ships + as its own ``{"image": {...}}`` block in the response. + + Args: + cameras: list of camera names; None = every camera. + width: per-camera width (defaults to camera's configured width). + height: per-camera height (same). + + Returns: + ``{"status", "content": [{"text": summary}, + {"text": "📸 cam1"}, {"image": {...}}, + {"text": "📸 cam2"}, {"image": {...}}, ...]}`` + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + names = self._active_camera_list(cameras) + if not names: + return {"status": "error", "content": [{"text": "No cameras in scene."}]} + content = [] + ok, failed = 0, 0 + low_var_warnings: list[str] = [] + for cam_name in names: + r = self.render(camera_name=cam_name, width=width, height=height) + if r.get("status") == "success": + ok += 1 + img_block = None + stats = None + for block in r.get("content", []): + if isinstance(block, dict): + if "image" in block and img_block is None: + img_block = block + if "json" in block and stats is None: + stats = block["json"] + if img_block is not None: + label = f"📸 {cam_name}" + # flag near-uniform frames (all black / all clear). + if stats and float(stats.get("pixel_variance", 99)) < 1.0: + warn = f"⚠️ camera '{cam_name}': image appears empty (variance < 1)" + label = f"{label} {warn}" + low_var_warnings.append(warn) + content.append({"text": label}) + content.append(img_block) + else: + failed += 1 + err = r.get("content", [{}])[0].get("text", "?") + content.append({"text": f"{cam_name}: {err}"}) + warn_suffix = f", {len(low_var_warnings)} low-variance" if low_var_warnings else "" + summary = ( + f"📸 Multi-camera snapshot at t={self._world.sim_time:.3f}s: " + f"{ok} ok, {failed} failed, {len(names)} requested{warn_suffix}" + ) + return { + "status": "success" if ok else "error", + "content": [{"text": summary}, *content], + } + + def start_cameras_recording( + self, + cameras=None, + output_dir=None, + fps=30, + width=None, + height=None, + name=None, + max_frames_per_camera=3000, + ): + """Start background capture of one ndarray buffer per camera. + + Strategy: the background thread collects raw RGB frames in memory + (one list per camera). ``stop_cameras_recording`` then flushes each + list to an MP4 on the main thread. This avoids a long-lived ffmpeg + subprocess pipe that would break under concurrent imageio writes + + policy-loop timing jitter. + + Memory cost: H*W*3 bytes * fps * duration * n_cams. For a 2s / 4-cam / + 320x240 / 15fps rollout: ~27 MB. Bounded by ``max_frames_per_camera``. + + Args: + cameras: list of camera names; None = every camera. + output_dir: where to write ``{tag}__{cam}.mp4``. + fps: capture rate. + width/height: per-frame size. + name: filename tag (auto if None). + max_frames_per_camera: safety cap on in-memory buffers. + """ + import os as _os + import threading as _threading + import time as _time + import uuid as _uuid + + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + if getattr(self, "_cams_rec_state", None) and self._cams_rec_state.get("running"): + cur = self._cams_rec_state["name"] + return { + "status": "error", + "content": [{"text": f"Already recording '{cur}'. Call stop_cameras_recording() first."}], + } + + names = self._active_camera_list(cameras) + if not names: + return {"status": "error", "content": [{"text": "No cameras to record."}]} + + out_dir = _os.path.abspath(output_dir or "/tmp/strands_robots/recordings") + _os.makedirs(out_dir, exist_ok=True) + tag = name or f"rec_{_uuid.uuid4().hex[:8]}" + + buffers = {cam: [] for cam in names} + paths = {cam: _os.path.join(out_dir, f"{tag}__{cam}.mp4") for cam in names} + + state = { + "running": True, + "name": tag, + "cameras": names, + "fps": fps, + "width": width, + "height": height, + "buffers": buffers, + "paths": paths, + "errors": dict.fromkeys(names, 0), + "output_dir": out_dir, + "started_at": _time.time(), + "thread": None, + "max_frames": max_frames_per_camera, + } + + def _loop(): + from strands_robots.simulation.policy_runner import _extract_frame_ndarray + + interval = 1.0 / fps + while state["running"]: + t0 = _time.time() + for cam in names: + if not state["running"]: + break + if len(state["buffers"][cam]) >= state["max_frames"]: + continue + try: + r = self.render(camera_name=cam, width=width, height=height) + arr = _extract_frame_ndarray(r) + if arr is not None: + state["buffers"][cam].append(arr) + else: + state["errors"][cam] += 1 + except Exception as e: + state["errors"][cam] += 1 + logger.debug("camera recorder (%s) error: %s", cam, e) + lag = _time.time() - t0 + if lag < interval: + _time.sleep(interval - lag) + + state["thread"] = _threading.Thread(target=_loop, daemon=True) + state["thread"].start() + self._cams_rec_state = state + + msg = ( + f"🎬 Recording {len(names)} camera(s) @ {fps} FPS → {out_dir}\n" + f" tag: {tag}\n" + f" cameras: {', '.join(names)}" + ) + return {"status": "success", "content": [{"text": msg}]} + + def stop_cameras_recording(self): + """Stop capture, flush buffers to MP4 on the MAIN thread. + + Runs ``imageio.get_writer``/``append_data``/``close`` here instead of + the recording thread so the ffmpeg pipe doesn't race with policy + timing jitter. Returns per-camera frame counts and paths. + """ + import os as _os + import time as _time + + state = getattr(self, "_cams_rec_state", None) + if not state or not state.get("running"): + # idempotent - 'already stopped' is a success, not an error. + return {"status": "success", "content": [{"text": "Was not recording cameras."}]} + + state["running"] = False + thread = state.get("thread") + if thread is not None: + thread.join(timeout=5.0) + + try: + import imageio.v2 as imageio + except ImportError: + return { + "status": "error", + "content": [{"text": "imageio not installed. pip install imageio imageio-ffmpeg"}], + } + + elapsed = _time.time() - state["started_at"] + lines = [ + f"🎬 Stopped '{state['name']}' after {elapsed:.1f}s", + f" output_dir: {state['output_dir']}", + ] + artifacts = [] + for cam in state["cameras"]: + frames_buffer = state["buffers"][cam] + path = state["paths"][cam] + errors = state["errors"][cam] + frames_written = 0 + size_kb = 0.0 + if frames_buffer: + writer = imageio.get_writer(path, fps=state["fps"], quality=8, macro_block_size=1) + try: + for arr in frames_buffer: + writer.append_data(arr) + frames_written += 1 + finally: + writer.close() + if _os.path.exists(path): + size_kb = _os.path.getsize(path) / 1024 + lines.append( + f" 📹 {cam:20s} {frames_written:>5d} frames {size_kb:>7.1f} KB " + f"({errors} errors) → {_os.path.basename(path)}" + ) + artifacts.append( + { + "camera": cam, + "path": path, + "frames": frames_written, + "errors": errors, + "size_kb": size_kb, + } + ) + + name = state["name"] + self._cams_rec_state = None + + return { + "status": "success", + "content": [ + {"text": "\n".join(lines)}, + {"json": {"recording": name, "artifacts": artifacts}}, + ], + } + + def get_cameras_recording_status(self): + """Cheap introspection of an ongoing multi-camera recording.""" + import time as _time + + state = getattr(self, "_cams_rec_state", None) + if not state or not state.get("running"): + return {"status": "success", "content": [{"text": "⚪ No active camera recording."}]} + + elapsed = _time.time() - state["started_at"] + lines = [f"🟢 Recording '{state['name']}' for {elapsed:.1f}s @ {state['fps']} FPS"] + for cam in state["cameras"]: + frames = len(state["buffers"][cam]) + lines.append(f" 📹 {cam:20s} {frames:>5d} frames ({state['errors'][cam]} errors)") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} diff --git a/strands_robots/simulation/mujoco/scene_ops.py b/strands_robots/simulation/mujoco/scene_ops.py new file mode 100644 index 0000000..b9f04d4 --- /dev/null +++ b/strands_robots/simulation/mujoco/scene_ops.py @@ -0,0 +1,980 @@ +"""XML round-trip injection/ejection for scene modification. + +Shared helper `_reload_scene_from_xml` handles the common pattern: +save XML → patch paths → modify → reload → copy state → re-discover joints. +""" + +import logging +import os +import re +import shutil +import tempfile +import xml.etree.ElementTree as ET +from typing import Any + +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco +from strands_robots.simulation.mujoco.mjcf_builder import MJCFBuilder, _camera_xyaxes_from_target, _sanitize_name + +logger = logging.getLogger(__name__) + + +def _patch_xml_paths(xml_content: str, robot_base_dir: str) -> str: + """Patch meshdir/texturedir in XML to absolute paths for tmpdir loading. + + Uses ElementTree for consistent XML manipulation throughout scene_ops. + Falls back to the original string if ET parsing fails (e.g. XML fragments). + """ + try: + root = ET.fromstring(xml_content) + except ET.ParseError: + # Fallback for malformed fragments - use regex as last resort + logger.debug("ET parse failed for _patch_xml_paths, using regex fallback") + meshdir_match = re.search(r'meshdir="([^"]*)"', xml_content) + if meshdir_match: + abs_meshdir = os.path.normpath(os.path.join(robot_base_dir, meshdir_match.group(1))) + xml_content = re.sub(r'meshdir="[^"]*"', f'meshdir="{abs_meshdir}"', xml_content) + texdir_match = re.search(r'texturedir="([^"]*)"', xml_content) + if texdir_match: + abs_texdir = os.path.normpath(os.path.join(robot_base_dir, texdir_match.group(1))) + xml_content = re.sub(r'texturedir="[^"]*"', f'texturedir="{abs_texdir}"', xml_content) + return xml_content + + compiler = root.find("compiler") + if compiler is None: + # No compiler element - add one with meshdir + compiler = ET.SubElement(root, "compiler") + # Insert at beginning (after root tag) + root.remove(compiler) + root.insert(0, compiler) + + existing_meshdir = compiler.get("meshdir", "") + compiler.set("meshdir", os.path.normpath(os.path.join(robot_base_dir, existing_meshdir))) + + existing_texdir = compiler.get("texturedir", "") + if existing_texdir or compiler.get("texturedir") is not None: + compiler.set("texturedir", os.path.normpath(os.path.join(robot_base_dir, existing_texdir))) + else: + compiler.set("texturedir", robot_base_dir) + + return ET.tostring(root, encoding="unicode", xml_declaration=False) + + +def _get_abs_meshdir(root: ET.Element) -> str: + """Extract the absolute meshdir from a parsed XML root. + + Returns empty string if no compiler/meshdir is set. + """ + compiler = root.find("compiler") + if compiler is not None: + return compiler.get("meshdir", "") + return "" + + +def _rewrite_mesh_paths( + robot_asset: ET.Element, + robot_meshdir: str, + scene_meshdir: str, +) -> None: + """Rewrite mesh ``file=`` attributes so they resolve under scene_meshdir. + + When merging robot assets into the scene XML, the scene's ```` governs where MuJoCo looks for mesh files. If the + robot's meshdir differs (e.g. ``robot_base/assets/`` vs ``robot_base/``), + each ```` must be adjusted to be correct relative to + the scene's meshdir. + + Strategy: convert each mesh file to an absolute path (via robot_meshdir), + then make it relative to scene_meshdir. If they share no common prefix, + fall back to absolute paths. + """ + if not robot_meshdir or not scene_meshdir: + return + # Normalize: ensure trailing sep for consistent joining + robot_meshdir = os.path.normpath(robot_meshdir) + scene_meshdir = os.path.normpath(scene_meshdir) + + if robot_meshdir == scene_meshdir: + return # No rewriting needed - meshdirs match + + for child in robot_asset: + if child.tag != "mesh": + continue + file_attr = child.get("file") + if not file_attr: + continue + # Build absolute path of the mesh file under robot's meshdir + abs_mesh = os.path.normpath(os.path.join(robot_meshdir, file_attr)) + # Make it relative to the scene's meshdir + try: + rel_path = os.path.relpath(abs_mesh, scene_meshdir) + except ValueError: + # On Windows, relpath fails across drives - use absolute + rel_path = abs_mesh + child.set("file", rel_path) + + # Also rewrite texture file paths that reference files on disk + for child in robot_asset: + if child.tag != "texture": + continue + file_attr = child.get("file") + if not file_attr: + continue + abs_tex = os.path.normpath(os.path.join(robot_meshdir, file_attr)) + try: + rel_path = os.path.relpath(abs_tex, scene_meshdir) + except ValueError: + rel_path = abs_tex + child.set("file", rel_path) + + +def _reload_scene_from_xml(world: SimWorld, scene_path: str) -> bool: + """Reload MuJoCo model from modified XML, preserving state. + + Copies qpos, qvel, ctrl from old model and re-discovers robot joint/actuator IDs. + + before copying existing state into the new MjData we explicitly call + ``mj_resetData`` so that joints NOT present in ``old_model`` (i.e. the + freshly-injected robot's joints) start from a well-defined zero state + rather than whatever garbage pybind11 happened to hand us from fresh + allocation. Old state is then layered on top per-joint-by-name so + previously-existing robots/objects keep their positions. + """ + mj = _ensure_mujoco() + new_model = mj.MjModel.from_xml_path(str(scene_path)) + new_data = mj.MjData(new_model) + + # zero the whole state buffer before copying old-state on top. + # Without this, freshly-added robots show nonzero qpos/qvel/ctrl from + # uninitialised memory and any observation taken before reset() is garbage. + mj.mj_resetData(new_model, new_data) + + # Copy state per-joint by name to handle layout shifts when injected + # bodies land earlier in the body-tree traversal. Flat-index copies + # (qpos[:old_nq]) are unsafe because MuJoCo allocates qpos in + # recursive body-tree order - a new body can shift existing entries. + old_model = world._model + old_data = world._data + for i in range(old_model.njnt): + jnt_name = mj.mj_id2name(old_model, mj.mjtObj.mjOBJ_JOINT, i) + if not jnt_name: + continue + new_jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if new_jid < 0: + continue # joint removed from scene + # Defensive: skip copy if joint type changed (extremely unlikely in + # inject/eject flow, but prevents stride mismatch → silent corruption). + if old_model.jnt_type[i] != new_model.jnt_type[new_jid]: + continue + # qpos: width depends on joint type (free=7, ball=4, hinge/slide=1) + jnt_type = old_model.jnt_type[i] + qpos_width = {0: 7, 1: 4, 2: 1, 3: 1}.get(int(jnt_type), 1) + old_adr = old_model.jnt_qposadr[i] + new_adr = new_model.jnt_qposadr[new_jid] + new_data.qpos[new_adr : new_adr + qpos_width] = old_data.qpos[old_adr : old_adr + qpos_width] + # qvel: width = joint DoF (free=6, ball=3, hinge/slide=1) + dof_width = {0: 6, 1: 3, 2: 1, 3: 1}.get(int(jnt_type), 1) + old_dof = old_model.jnt_dofadr[i] + new_dof = new_model.jnt_dofadr[new_jid] + new_data.qvel[new_dof : new_dof + dof_width] = old_data.qvel[old_dof : old_dof + dof_width] + + # Copy ctrl per-actuator by name (actuator order may also shift) + for i in range(old_model.nu): + act_name = mj.mj_id2name(old_model, mj.mjtObj.mjOBJ_ACTUATOR, i) + if not act_name: + continue + new_aid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_ACTUATOR, act_name) + if new_aid >= 0: + new_data.ctrl[new_aid] = old_data.ctrl[i] + + mj.mj_forward(new_model, new_data) + + world._model = new_model + world._data = new_data + + # Persist the current scene XML so subsequent mj_saveLastXML calls can + # reset the MuJoCo global state. Without this, any render/renderer + # creation poisons mj_saveLastXML for inject/eject round-trips. + try: + with open(scene_path) as _f: + world._backend_state["xml"] = _f.read() + except OSError: + # Best-effort - don't fail the reload just because we can't read back. + pass + + # Re-discover robot joints/actuators (IDs may shift). + # Try namespaced name first (multi-robot case), fall back to raw. + for robot in world.robots.values(): + robot.joint_ids = [] + robot.actuator_ids = [] + pfx = robot.namespace or "" + for jnt_name in robot.joint_names: + jid = -1 + if pfx: + jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, pfx + jnt_name) + if jid < 0: + jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jid >= 0: + robot.joint_ids.append(jid) + for i in range(new_model.nu): + jnt_id = new_model.actuator_trnid[i, 0] + if jnt_id in robot.joint_ids: + robot.actuator_ids.append(i) + if not robot.actuator_ids: + # Last-resort fallback: all actuators (single-robot scenes). + if len(world.robots) == 1: + for i in range(new_model.nu): + robot.actuator_ids.append(i) + + return True + + +def _get_robot_base_dir(world: SimWorld) -> str | None: + """Get the base directory for resolving MJCF asset references. + + For multi-robot scenes with different asset directories, use + ``_get_all_robot_base_dirs()`` instead. + + Falls back to the scene base dir when the world was loaded via + ``load_scene`` and has no robots yet (otherwise mesh ``file=`` refs + inside a round-tripped scene XML would fail to resolve under tmpdir). + """ + if world._backend_state.get("robot_base_xml", ""): + return os.path.dirname(os.path.abspath(world._backend_state.get("robot_base_xml", ""))) + scene_base = world._backend_state.get("scene_base_dir", "") + if scene_base and os.path.isdir(scene_base): + return scene_base + return None + + +def _get_all_robot_base_dirs(world: SimWorld) -> list[str]: + """Return a deduplicated list of directories containing robot model files. + + Each robot's ``urdf_path`` points to its MJCF/URDF source. The directory + of each path may contain mesh assets that the scene XML references. + """ + dirs: list[str] = [] + seen: set[str] = set() + for robot in world.robots.values(): + d = os.path.dirname(os.path.abspath(robot.urdf_path)) + if d not in seen: + seen.add(d) + dirs.append(d) + # Also include the legacy single-robot path if set. + legacy = _get_robot_base_dir(world) + if legacy and legacy not in seen: + dirs.append(legacy) + return dirs + + +def _save_and_patch_xml(world: SimWorld, tmpdir: str, filename: str) -> str: + """Save current model to XML in tmpdir and patch asset paths. + + Note: MuJoCo's ``mj_saveLastXML`` is a global function that always + writes the *last loaded* model's XML, ignoring the ``model`` argument. + Any renderer creation (``mj.Renderer``) or ancillary model load between + our last scene compile and this save will poison the global → we get + some *other* model's XML and the inject/eject XML round-trip fails + silently (e.g. "Body 'cube' not found in MJCF XML"). + + To work around this, we first reload our own stored scene XML into the + MuJoCo global state (via ``MjModel.from_xml_string``). The resulting + ``_tmp`` model is discarded - its only purpose is to reset + ``mj_saveLastXML``'s internal pointer. + + Multi-robot note: uses the first robot's base dir for compiler paths. + Individual robot mesh paths are rewritten to absolute during + inject_robot_into_scene (via _rewrite_mesh_paths), so the scene-level + meshdir only needs to resolve for the primary robot. Future enhancement: + convert all mesh paths to absolute during injection to eliminate + first-wins coupling entirely. + """ + mj = _ensure_mujoco() + scene_path = os.path.join(tmpdir, filename) + + stored_xml = world._backend_state.get("xml") + if stored_xml: + _tmp = mj.MjModel.from_xml_string(stored_xml) # noqa: F841 + mj.mj_saveLastXML(scene_path, _tmp) + else: + mj.mj_saveLastXML(scene_path, world._model) + + robot_base_dir = _get_robot_base_dir(world) + if robot_base_dir and os.path.isdir(robot_base_dir): + with open(scene_path) as f: + xml_content = f.read() + xml_content = _patch_xml_paths(xml_content, robot_base_dir) + with open(scene_path, "w") as f: + f.write(xml_content) + + return scene_path + + +def _prefix_robot_names(robot_root: Any, prefix: str) -> None: + """Prefix every named element and reference in a robot MJCF so that + multiple robots with the same ``data_config`` can coexist in one scene. + + Without this, two ``so101`` robots share body names (``base``, ``gripper``, + ...), joint names (``shoulder_pan``, ...), actuator names, etc. MuJoCo + requires all top-level names to be globally unique and rejects the merged + XML with ``"repeated name 'base' in body"``. + + The prefix is applied in-place to: + - element ``name`` attributes (bodies, joints, actuators, sites, geoms, + sensors, tendons, equality constraints, keyframes) + - reference attributes that point *into* the robot namespace: + ``joint``, ``body``, ``site``, ``geom``, ``tendon``, ``actuator``, + ``body1``, ``body2``, ``joint1``, ``joint2`` + + Asset references (mesh, material, texture, hfield) and class references + are NOT prefixed - they are shared by same-config robots (which is the + whole point of the dedupe in assets/defaults). + + Args: + robot_root: The parsed ```` root of the robot XML. + prefix: The robot instance name, used as a namespace prefix. + """ + pfx = f"{prefix}/" + + # Tags whose "name" attribute identifies a unique element in the merged + # scene. Each instance must get prefixed. + _NAMED_TAGS = { + "body", + "joint", + "geom", + "site", + "camera", + "light", + "actuator", + "general", + "motor", + "position", + "velocity", + "sensor", + "force", + "torque", + "jointpos", + "jointvel", + "framepos", + "framequat", + "frameangvel", + "framelinvel", + "framelinacc", + "frameangacc", + "accelerometer", + "gyro", + "magnetometer", + "rangefinder", + "touch", + "subtreecom", + "subtreelinvel", + "subtreeangmom", + "velocimeter", + "user", + "tendon", + "fixed", + "spatial", + "equality", + "connect", + "weld", + "joint_equality", + "tendon_equality", + "key", # keyframes + } + + # Attributes that reference named elements (in the robot namespace). + _REF_ATTRS = { + "joint", + "body", + "site", + "geom", + "tendon", + "actuator", + "body1", + "body2", + "joint1", + "joint2", + "childclass", # default classes - prefixed too since we keep per-robot ones? No - keep shared. + "target", + } + # We don't prefix "childclass" because classes are shared (deduped) across + # same-config robots. Remove it from the set. + _REF_ATTRS.discard("childclass") + + def visit(elem: Any) -> None: + # Rename ``name`` attribute if this tag is in the named set. + if elem.tag in _NAMED_TAGS: + orig = elem.get("name", "") + if orig and not orig.startswith(pfx): + elem.set("name", pfx + orig) + + # Rewrite reference attributes (they point to robot-local elements). + for attr in _REF_ATTRS: + val = elem.get(attr) + if val and not val.startswith(pfx): + elem.set(attr, pfx + val) + + for child in elem: + visit(child) + + # We only want to prefix elements inside: + # - worldbody (bodies, their children) + # - actuator + # - sensor + # - equality + # - tendon + # - keyframe + # We do NOT prefix contents of , , ,