diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index 79b15d9..f389ad9 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -26,6 +26,11 @@ jobs: python-version: '3.12' cache: 'pip' + - name: Install system dependencies (OpenGL for MuJoCo) + run: | + sudo apt-get update + sudo apt-get install -y libosmesa6-dev ffmpeg + - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -35,4 +40,6 @@ jobs: run: hatch run lint - name: Run tests + env: + MUJOCO_GL: osmesa run: hatch run test -x --strict-markers diff --git a/.gitignore b/.gitignore index 2e430c6..28eab63 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,7 @@ dist .strands_robots .coverage .ideation/ +MUJOCO_LOG.TXT +TASKS.md +TASKS_TO_FIX_85.md +.coverage.* diff --git a/AGENTS.md b/AGENTS.md index 6c1ad5a..59275a6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,4 +1,4 @@ -# AGENTS.md — strands-labs/robots +# AGENTS.md - strands-labs/robots ## Overview @@ -11,7 +11,7 @@ > **RULE**: ALWAYS use the project board to track work. When creating follow-up items, > create GitHub issues and add them to this board with Status + Priority set. -> Never track work only in local markdown — the board is the source of truth. +> Never track work only in local markdown - the board is the source of truth. ## Repository Structure @@ -61,20 +61,20 @@ hatch run format # ruff check --fix, ruff format ``` > **Note**: Hatch uses `uv` as installer (`installer = "uv"` in pyproject.toml) for faster -> environment creation. No manual uv install needed — hatch handles it. +> environment creation. No manual uv install needed - hatch handles it. ## Key Conventions -1. **Python 3.12+** — `requires-python = ">=3.12"` (LeRobot >=0.5.0 requires 3.12) -2. **Dependency bounds** — `>=1.0` deps: cap major. `<1.0` deps: cap minor. E.g. `lerobot>=0.5.0,<0.6.0` -3. **`__init__.py` must be thin** — exports only, no logic -4. 
**Imports at file top** — unless lazy-loading heavy deps with documented reason -5. **Raise on fatal errors** — never warn-and-continue if the system will behave unexpectedly -6. **No silent defaults on error** — returning zero-valued actions on failure is forbidden -7. **Use `require_optional()`** — from `strands_robots/utils.py` for all optional deps -8. **Integration tests required** — each policy needs `tests_integ/` tests with real inference -9. **Test behavior, not implementation** — assert on outputs, not internal state -10. **No dead code** — if it's not called and not part of base class, delete it +1. **Python 3.12+** - `requires-python = ">=3.12"` (LeRobot >=0.5.0 requires 3.12) +2. **Dependency bounds** - `>=1.0` deps: cap major. `<1.0` deps: cap minor. E.g. `lerobot>=0.5.0,<0.6.0` +3. **`__init__.py` must be thin** - exports only, no logic +4. **Imports at file top** - unless lazy-loading heavy deps with documented reason +5. **Raise on fatal errors** - never warn-and-continue if the system will behave unexpectedly +6. **No silent defaults on error** - returning zero-valued actions on failure is forbidden +7. **Use `require_optional()`** - from `strands_robots/utils.py` for all optional deps +8. **Integration tests required** - each policy needs `tests_integ/` tests with real inference +9. **Test behavior, not implementation** - assert on outputs, not internal state +10. **No dead code** - if it's not called and not part of base class, delete it ## PR Workflow @@ -94,7 +94,7 @@ hatch run format # ruff check --fix, ruff format `asimovinc/asimov-v0` which has `sim-model/xmls/asimov.xml` + `sim-model/assets/`. The `_safe_join` helper in `strands_robots/utils.py` guards against traversal (`..`). -- **Auto-download strategy** — every robot with an `asset` block must declare +- **Auto-download strategy** - every robot with an `asset` block must declare exactly one of: 1. `asset.robot_descriptions_module` (preferred) 2. 
`asset.source` with `type: "github"` diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..450db38 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,268 @@ +# CHANGELOG + +All notable behavioural changes to `strands-robots` are logged here. Follows +[Keep a Changelog](https://keepachangelog.com/) conventions. + +## Unreleased - PR #85 (MuJoCo backend remediation) + +### MJCF builder refactor: string-concat -> MjSpec AST (closes #121, #122-#126) + +The ``MJCFBuilder`` string-concat path and the ``scene_ops`` XML-round-trip +machinery (~700 lines total) are replaced by direct manipulation of +``mujoco.MjSpec`` - the editable MJCF AST shipped with MuJoCo 3.2+. + +What changed under the hood: +- **New module** ``strands_robots/simulation/mujoco/spec_builder.py``. The + ``SpecBuilder`` class owns scene construction + mutation (``build``, + ``add_object``, ``remove_body``, ``add_camera``, ``remove_camera``, + ``attach_robot``, ``from_mjcf_string``, ``from_file``). +- **Deleted**: ``strands_robots/simulation/mujoco/mjcf_builder.py`` (273 + lines of f-string MJCF and the ``_camera_xyaxes_from_target`` helper). +- **Rewrote**: ``strands_robots/simulation/mujoco/scene_ops.py`` from + ~980 lines of tmpdir + ``mj_saveLastXML`` + ``ElementTree`` round-trips + down to ~295 lines that go through ``spec.recompile(model, data)``. +- **Bumped**: ``mujoco>=3.0.0`` -> ``>=3.2.0`` in ``pyproject.toml`` so + ``MjSpec`` is always available. Current hatch env runs 3.8.0. + +Agent-visible wins: +- **New action** ``replace_scene_mjcf(xml=...)`` - atomically replace the + whole scene with agent-authored MJCF. Validated by actually compiling + it, so ``<tendon>``, ``<equality>``, ``<sensor>``, custom solref/solimp, + sites, hfield, etc. all work without needing new ``SimObject`` shape + vocabulary. On malformed XML it returns a clean error dict (no process + abort). +- **``ellipsoid`` shape** now works in ``add_object`` - it's a free + bonus MuJoCo geom type the string-concat builder rejected.
+- **Camera orientation** uses ``quat`` (computed via + ``mujoco.mju_mat2Quat``) instead of a hand-rolled ``xyaxes`` string. + Compiled ``cam_mat0`` is numerically identical within ~4e-7. +- **``spec.recompile(model, data)``** preserves existing joint qpos/qvel + for unchanged joints automatically - no manual "copy state by name" + loop. Object freejoints added post-compile get initialised to the + body's ``pos``/``quat``. +- **No more XML injection surface**: names go straight into MjSpec which + validates them itself, so the old ``_sanitize_name`` regex gate + + fuzz test are no longer needed. + +Downstream API is unchanged: ``add_object``, ``add_robot``, ``remove_object``, +``remove_robot``, ``add_camera``, ``remove_camera``, ``load_scene`` all keep +their tool-action signatures. Tests that asserted on exact XML strings +were rewritten to assert on compiled ``MjModel`` properties (``cam_mat0``, +``mj_name2id``) so they are representation-agnostic. + +Known constraint: ``remove_robot`` now rebuilds the scene from scratch +(drops joint qpos state) rather than going through ``spec.delete()`` on +attached bodies. This sidesteps a MuJoCo 3.8 double-free bug where +``spec.delete(attached_body)`` + interpreter shutdown crashes. Trade-off +is documented in ``scene_ops.eject_robot_from_scene``. + +### Breaking + +These changes tighten the MuJoCo AgentTool contract. Legacy callers that +silently worked by accident will now receive a clear error instead: + +- **Router input validation**: The ``_dispatch_action`` router rejects any + top-level parameter that isn't declared on the target method. Passing + ``step(num_steps=5)`` (wrong name) or ``set_gravity(device="mps")`` + (stray kwarg) now errors with *"Unknown parameter X for action Y. + Valid: [...]"* instead of silently dropping the value. Methods whose + Python signature includes ``**kwargs`` (e.g. ``add_object``) keep their + pass-through semantics. 
+- **Missing required args**: produce *"Action X requires parameter Y."* + instead of a raw Python ``TypeError``. +- **Vector dimension validation**: ``position``, ``target``, ``origin``, + ``force``, ``torque``, ``gravity``, ``direction``, ``point``, ``orientation`` + (quaternion), and ``color`` (rgba) are all validated for length + numeric + dtype before reaching numpy/MuJoCo. +- **Camera orientation**: ``add_camera(target=[x,y,z])`` is now honoured + by baking the look-at orientation into the MJCF ``<camera>`` (originally as ``xyaxes``; as a ``quat`` since the MjSpec refactor above). Previously the target + was silently dropped and every custom camera rendered a default view. + Degenerate case (``target == position``) errors. +- **Render camera validation**: ``render(camera_name="missing")`` errors + with *"Camera 'missing' not found."* instead of silently falling back + to the free camera while claiming to render from the named one. +- **Raycast zero-direction guard**: ``raycast(direction=[0,0,0])`` now + errors with *"direction vector is zero-length"*. Previously MuJoCo's + C-level ``mj_ray`` would abort the Python process. +- **apply_force requires a non-zero vector**: passing neither ``force`` + nor ``torque`` (or both zero) errors. Previously the call silently + succeeded with no effect. +- **step(n_steps<0)** rejected (previously it corrupted ``step_count``). +- **Negative mass / timestep / size** rejected per shape; previously + ``set_body_properties(mass=-1)`` and ``set_timestep(-0.01)`` silently + succeeded. +- **Plane objects auto-static**: ``add_object(shape="plane")`` now forces + ``is_static=True`` (planes are infinite in MuJoCo). Explicit + ``is_static=False`` on a plane is a hard error. +- **Duplicate camera name** rejected. Previously a second ``add_camera`` + with an existing name silently overwrote the registry entry while + leaving the old camera in the XML - ghost behaviour. Use + ``remove_camera`` + ``add_camera`` to replace.
+- **stop_policy(robot_name='')** errors with *"stop_policy requires + 'robot_name'."* instead of silently matching the first robot. +- **eval_policy** requires an explicit ``robot_name``. Default + ``n_episodes`` lowered from 10 to 1. +- **register_urdf** validates the path: file must exist, be a file, and + be readable. Previously bad paths were cached and blew up later. + +### Recording backend split + +- ``start_recording`` (LeRobotDataset: parquet + per-camera MP4) still + requires the ``[lerobot]`` extra. Its error message when lerobot is + missing now points callers at ``start_cameras_recording`` for plain + MP4 (which runs under ``[sim-mujoco]`` alone via imageio-ffmpeg). +- No API change - the fix is informational. + +### Resource hygiene + +- ``destroy()`` and ``cleanup()`` now close renderers on the main thread + and empty the TLS cache. Previously each ``create_world/destroy`` + cycle leaked one ``mujoco.Renderer`` + its GL context (~33 MB per + cycle measured). Worker-thread renderers still release themselves on + thread teardown (we avoid cross-thread ``close()`` to prevent + ``cgl.free()`` SIGSEGVs on macOS). +- ``get_mass_matrix`` and ``get_contacts`` run ``mj_forward`` first so + values are valid immediately after a ``reset`` or ``add_robot`` + (previously returned stale / uninitialised memory). + +### Concurrency guards + +Write-mutations are now refused while a policy is running on any robot +in the world. Previously these could race the policy worker thread and +produce undefined behaviour or SIGSEGV: + + reset, set_gravity, set_timestep, set_joint_positions, + set_joint_velocities, apply_force, set_body_properties, + set_geom_properties, load_state, randomize, move_object + +The error now lists *which* robot(s) are active so the LLM can +``stop_policy`` on each without guessing: *"Cannot 'X' while a policy +is running on 'armA', 'armB'. 
Stop it first: action='stop_policy'."* + +### Concurrent per-robot policies (GH #114) + +Multiple ``start_policy`` calls on *different* robots now run +concurrently. MuJoCo physics is still serialized via ``self._lock`` +(``mj_step`` and ``ctrl[]`` writes are not thread-safe for concurrent +mutation), but each policy owns a disjoint slice of ``data.ctrl[]`` so +two VLA arms can operate in the same scene without semantic conflict. + +- ``start_policy("armA")`` + ``start_policy("armB")`` both succeed. + Second call no longer hits a global "policy already running" gate. +- ``start_policy`` on the *same* robot while its policy is active + still errors (unchanged). +- ``remove_robot("X")`` now gracefully stops X's own policy before + removing, instead of requiring a prior ``stop_policy("X")``. Still + errors if a *different* robot has an active policy (XML round-trip + invalidates cached IDs everywhere). +- New action ``list_policies_running`` returns the names of robots + with live policies. Prunes completed Futures as a side-effect. +- Completed policy Futures are no longer retained forever in + ``_policy_threads`` (GH #120 companion fix). + +### Policy-hook robustness (GH #117) + +``PolicyRunner.run`` previously caught *all* ``on_frame`` exceptions at +WARN level and kept iterating. A recording hook with a typo'd observation +key would log 500 lines and produce an empty dataset. Now we count +*consecutive* failures and abort the episode after a threshold (default +5, tunable via new ``max_onframe_failures`` kwarg). + +- A single transient failure still logs + continues; counter resets on + the next successful call. +- ``N`` consecutive failures raise ``RuntimeError`` so ``run()`` returns + ``status='error'`` with a clear message, preventing silent dataset + corruption. + +### Cleanup graceful shutdown (GH #116) + +``Simulation.cleanup()`` no longer races the policy worker. 
Previously +cleanup set ``self._world = None`` and called ``executor.shutdown(wait=False)`` +nearly simultaneously - a policy still inside ``mj_step`` segfaulted on +freed arrays. Now cleanup: + +1. Signals every live policy to stop (``policy_running = False``). +2. Awaits each outstanding Future with a bounded timeout (default 5s, + overridable via new ``cleanup(policy_stop_timeout=...)`` kwarg). +3. Only AFTER workers unwind do we null ``self._world`` and tear down + renderers / viewer / executor. + +Wedged workers that don't stop in time get logged as a warning - cleanup +proceeds rather than hanging the host process on exit. + +### Error message consistency + +- All "no world" paths return the same string: + *"No world. Call create_world (or load_scene) first."* +- Unknown-name errors use a uniform ``<Kind> 'X' not found.`` shape + (Robot / Object / Body / Geom / Joint / Sensor / Camera / Checkpoint). +- ``stop_recording``, ``stop_cameras_recording``, ``stop_policy``, + ``close_viewer`` are now **idempotent**: calling them when nothing + is running returns ``status="success"`` with a *"Was not ..."* message + so callers can invoke them unconditionally. +- ``get_recording_status`` returns success in every lifecycle state + (no world / not recording / recording). + +### Deprecations + +- **add_robot name-as-registry fallback**: passing ``name="my_bot"`` + without ``urdf_path`` or ``data_config`` used to resolve ``my_bot`` in + the model registry. This now fires a ``DeprecationWarning``. Use + ``add_robot(name="...", data_config="<registry_name>")`` instead. The fallback will + be removed in the next major release. + +### New / extended actions + +- ``forward_kinematics(body_name="X")`` filters to a single body. +- ``get_features(robot_name="X")`` filters to a single robot's joints + and actuators. +- ``set_geom_properties(geom_name="X")`` accepts the bare object name + as an alias for the injected ``"{name}_geom"``.
+- ``render_all`` flags cameras whose frame has near-zero pixel variance + (``"⚠️ camera 'X': image appears empty (variance < 1)"``). +- ``render_depth`` surfaces MuJoCo's one-time ``ARB_clip_control`` + warning in the response text on macOS, so the LLM knows when depth + accuracy is reduced. +- ``render`` / ``render_depth``: width/height validated up front; + oversized requests get a plain-English message naming the actual + framebuffer cap (``<global offwidth/offheight>``) instead of MuJoCo's raw + error. +- ``run_policy`` / ``start_policy``: accept optional ``n_steps`` + (primary) or legacy ``max_steps`` as an alternative to + ``duration``+``control_frequency``. ``duration = n_steps / + control_frequency`` when ``n_steps`` is set. +- **New ``list_policies_running``** action returns the names of robots + with a live policy - pairs with the new concurrent-policy support + (see *Concurrent per-robot policies* above). +- ``randomize(randomize_physics=True)`` now reports per-body mass scales + and per-geom friction scales in the response (not just range + endpoints). +- ``get_contacts`` resolves unnamed geoms to + ``"<body>/geom_<id>"`` so contact pairs are always human-readable. +- ``get_sensor_data(sensor_name="X")`` on a model with no sensors now + distinguishes *"Sensor 'X' not found. Model has no sensors."* from + the generic "no sensors in model" success. + +### Tests + +- New: ``tests/simulation/mujoco/test_agenttool_contract.py`` - ~50 + tests that lock in router validation, tool_spec ↔ method parity, + unified error messages, idempotent stop family, ``mj_forward`` before + reads, render-dim validation, feature filters, camera duplicate + policy, plane auto-static, policy horizon unification, and more. +- New: ``tests/simulation/mujoco/test_renderer_hygiene.py`` - 4 tests + asserting TLS cache is emptied on ``destroy``, renderer reuse works + for identical ``(w,h)``, and ``create_world`` after ``destroy`` + rebuilds cleanly.
+- New: ``tests/simulation/mujoco/test_recording_backends.py`` - 2 tests + (one skipped when ``lerobot`` IS installed) pinning the + MP4-without-lerobot backend. +- New: ``tests/simulation/mujoco/test_input_validation.py`` - 11 tests + for step/raycast/apply_force validation. +- New: ``tests_integ/test_resource_hygiene.py`` - 3 integration tests + (require ``psutil``): 50 create/destroy cycles grow RSS < 50 MB; 500 + renders at fixed dims grow RSS < 100 MB; TLS cache cleared on destroy. + +Test count: **256 → 362** (+106 new regression tests), zero +regressions. ``hatch run lint`` (ruff + mypy) clean across 102 source +files. diff --git a/IDEA.md b/IDEA.md new file mode 100644 index 0000000..6a8bfbd --- /dev/null +++ b/IDEA.md @@ -0,0 +1,361 @@ +# IDEA: MJCF AST Refactor — Replace String-Concat Builder with `mujoco.MjSpec` + +> **Status:** Proposal / exploration. Nothing implemented. Safe to hand to an autonomous agent as a scoped, staged refactor. +> **Target:** `strands_robots/simulation/mujoco/mjcf_builder.py` + large chunks of `scene_ops.py` +> **Bump required:** `mujoco>=3.2.0` (currently pinned `>=3.0.0,<4.0.0` in `pyproject.toml` — already installed at **3.8.0** in hatch env) + +--- + +## TL;DR + +`mujoco.MjSpec` is the official editable MJCF AST shipped with MuJoCo 3.2+. We currently build MJCF by string-concatenating f-strings and mutate scenes by round-tripping XML through `xml.etree.ElementTree`. Switching to `MjSpec` deletes ~600 lines of hand-rolled string munging, kills several known bug classes (camera orientation, keyframe-dim mismatch, mesh-path patching), and unlocks two new capabilities: + +1. **Agent-authored raw MJCF** — validated by *actually compiling it*, with a clean fallback path. +2. **Fine-grained live mutation** — add/remove/modify bodies, geoms, sensors, tendons, equalities without a tmpdir + regex roundtrip. 
+ +--- + +## Current state (what's hardcoded today) + +### File: `strands_robots/simulation/mujoco/mjcf_builder.py` (273 lines) + +Three things baked in: + +1. **String-concatenated XML.** Every MJCF element is a Python f-string: + ```python + parts.append(f'<body name="{name}" pos="{pos}">') + ``` + Every new element type = new f-string. Names require a custom `_sanitize_name` regex to prevent XML injection — which only exists because we're doing string concat. + +2. **Shape vocabulary is frozen.** `_object_xml()` is an `elif` ladder over + `box / sphere / cylinder / capsule / mesh / plane`. `SimObject` has a + single `shape: str` + `size: list[float]` — no room for `ellipsoid`, + `hfield`, sites, tendons, equality constraints, pairs, sensors, custom + materials per-object, friction/solref/solimp tuning, etc. + +3. **`_camera_xyaxes_from_target`** — 72 lines of linear algebra + bug-fix + commentary that exist *solely* because MuJoCo's `mode="fixed"` cameras + ignore the `target` attribute, so we hand-compute `xyaxes`. With MjSpec + we just set `cam.targetbody` or let the compiler emit the quat.
+ +### File: `strands_robots/simulation/mujoco/scene_ops.py` (980 lines) + +The XML round-trip machinery: + +| Helper | Lines | What it does | +|---|---|---| +| `_patch_xml_paths` | ~40 | Rewrites `meshdir`/`texturedir` to absolute paths after `mj_saveLastXML` | +| `_get_abs_meshdir` / `_rewrite_mesh_paths` | ~60 | Patches `<mesh file="...">` paths across robots loaded from different base dirs | +| `_prefix_robot_names` | ~120 | Tree-walks an MJCF root and namespaces every `name=` attribute | +| `_namespace_robot_default_classes` | ~60 | Namespaces `<default class="...">` blocks to avoid collisions on merge | +| `_collect_existing_class_names` | ~15 | Class-name collision avoidance helper | +| `inject_robot_into_scene` | ~50 | Load URDF → save → patch paths → prefix names → merge into scene XML → reload | +| `inject_object_into_scene` | ~34 | `ET.parse` → find `<worldbody>` → append → delete `<keyframe>` (freejoint adds qpos) → write → reload | +| `eject_body_from_scene` | ~45 | `ET.parse` → find body by name → remove → write → reload | +| `eject_robot_from_scene` | ~70 | Same, but also cleans actuators/sensors/equality referencing the robot | +| `inject_camera_into_scene` | ~44 | `ET.parse` → append `<camera>` → reload | + +All of this is reimplementing what `MjSpec` gives for free. + +### Downstream consumers +- `simulation.py:346` — `xml = MJCFBuilder.build_objects_only(self._world)` +- `simulation.py:292` — `_recompile_world()` rebuilds from scratch via `MJCFBuilder` + `mj.MjModel.from_xml_string` +- `scene_ops.py:790` — `MJCFBuilder._object_xml(obj, indent=4)` called inside `inject_object_into_scene` +- (Search for `MJCFBuilder` to find all call sites.) + +--- + +## Target state: `MjSpec`-backed world + +### Core idea +`SimWorld` grows one new field via `_backend_state`: +```python +world._backend_state["spec"]: mujoco.MjSpec +``` +`world._model` stays an `MjModel` (unchanged public contract). The *source of truth* for scene structure is the `MjSpec`. The model is derived from it via `spec.compile()`.
+ +### New module: `spec_builder.py` (replaces `mjcf_builder.py`) + +Sketch (pseudocode — agent should flesh out): + +```python +import mujoco +from mujoco import mjtGeom + +SHAPE_MAP = { + "box": mjtGeom.mjGEOM_BOX, + "sphere": mjtGeom.mjGEOM_SPHERE, + "cylinder": mjtGeom.mjGEOM_CYLINDER, + "capsule": mjtGeom.mjGEOM_CAPSULE, + "ellipsoid":mjtGeom.mjGEOM_ELLIPSOID, # bonus — free with this refactor + "mesh": mjtGeom.mjGEOM_MESH, + "plane": mjtGeom.mjGEOM_PLANE, +} + +class SpecBuilder: + @staticmethod + def build(world: SimWorld) -> mujoco.MjSpec: + spec = mujoco.MjSpec() + spec.compiler.angle = mujoco.mjtAngle.mjANGLE_RADIAN + spec.compiler.autolimits = True + spec.option.timestep = world.timestep + spec.option.gravity = world.gravity + + # visual / asset / lights / ground + SpecBuilder._add_defaults_and_assets(spec, world) + + # cameras — use targetbody or add_frame trick instead of xyaxes math + for cam in world.cameras.values(): + SpecBuilder._add_camera(spec, cam) + + # objects + for obj in world.objects.values(): + SpecBuilder._add_object(spec, obj) + + # robots — via attach() (see compose below) + for robot in world.robots.values(): + SpecBuilder._attach_robot(spec, robot) + + return spec + + @staticmethod + def _add_object(spec, obj: SimObject): + body = spec.worldbody.add_body( + name=obj.name, pos=obj.position, quat=obj.orientation, + ) + if not obj.is_static: + body.add_freejoint(name=f"{obj.name}_joint") + # inertial auto-computed from geoms + mass in MjSpec + geom_kwargs = dict( + name=f"{obj.name}_geom", + type=SHAPE_MAP[obj.shape], + rgba=obj.color, + ) + if obj.shape == "mesh": + geom_kwargs["meshname"] = f"mesh_{obj.name}" + else: + geom_kwargs["size"] = _normalize_size(obj.shape, obj.size) + body.add_geom(**geom_kwargs) +``` + +No `_sanitize_name` — `MjSpec` validates names itself. +No `_camera_xyaxes_from_target` — use `cam.targetbody` or set `cam.quat` from a helper that MuJoCo's own code verifies. +No f-strings, no escaping, no regex. 
+ +### Robot composition (replaces `compose_multi_robot_scene` + all the prefix helpers) + +Current ~200 lines of `_prefix_robot_names` + `_namespace_robot_default_classes` collapse to: + +```python +robot_spec = mujoco.MjSpec.from_file(robot.urdf_path) # URDF → spec +frame = scene_spec.worldbody.add_frame( + pos=robot.position, quat=robot.orientation, +) +scene_spec.attach(robot_spec, prefix=f"{robot.name}_", frame=frame) +``` + +`attach()` handles: +- Name prefixing across bodies, joints, geoms, actuators, sensors, sites. +- Default class namespacing. +- Asset deduplication (meshes, textures, materials). +- Keyframe merging (or not — configurable). + +### Live mutation (replaces `inject_*` / `eject_*`) + +`scene_ops.inject_object_into_scene` before: + +```python +# tmpdir, save XML, parse with ET, find worldbody, append child, +# delete keyframes, write, reload from path, copy state, re-discover joints +``` + +After: + +```python +def inject_object_into_scene(world, obj): + spec = world._backend_state["spec"] + SpecBuilder._add_object(spec, obj) + world._model, world._data = spec.recompile(world._model, world._data) + # recompile preserves qpos for unchanged joints; new freejoint qpos = pos/quat from body + return True +``` + +`eject_body_from_scene`: + +```python +def eject_body_from_scene(world, body_name): + spec = world._backend_state["spec"] + body = spec.body(body_name) # raises KeyError if missing + body.delete() + world._model, world._data = spec.recompile(world._model, world._data) + return True +``` + +### Agent-authored raw MJCF (the *new* capability) + +Add a third tool-facing entry point: + +```python +def replace_scene_mjcf(world, xml: str): + """Atomically swap the whole scene to agent-written MJCF. + Validated by actually compiling it. Raises on failure + with the MuJoCo compiler error verbatim. 
+ """ + new_spec = mujoco.MjSpec.from_string(xml) + new_model = new_spec.compile() # raises if invalid + new_data = mujoco.MjData(new_model) + world._backend_state["spec"] = new_spec + world._model, world._data = new_model, new_data + +def patch_scene_mjcf(world, ops: list[dict]): + """Apply a list of structured ops to the live spec. + ops = [ + {"op": "add_body", "parent": "world", "name": "foo", "pos": [...]}, + {"op": "add_geom", "body": "foo", "type": "box", "size": [...], "rgba": [...]}, + {"op": "set_attr", "path": "body/foo", "attr": "pos", "value": [1,0,0]}, + {"op": "delete", "path": "body/foo"}, + ] + """ + spec = world._backend_state["spec"] + for op in ops: + _apply_op(spec, op) # small dispatcher + world._model, world._data = spec.recompile(world._model, world._data) +``` + +Both compose cleanly with the `SimObject`/`SimRobot` dataclasses — those remain the *easy path*. Raw MJCF is the *escape hatch*, matching the `_backend_state` pattern already documented in `models.py`. + +--- + +## Work breakdown (staged, safe) + +### Stage 0 — Prep +- [ ] Bump `mujoco>=3.2.0` in `pyproject.toml` (`sim-mujoco` optional group). Check current envs; most already have 3.8. +- [ ] Add `strands_robots/simulation/mujoco/spec_builder.py` skeleton. No call sites yet. +- [ ] Unit test: `test_spec_builder_smoke.py` — create `SimWorld` with 2 objects, 1 camera, build spec, compile, assert `model.nbody >= 3`, `model.ncam == 1`, `model.ngeom >= 2`. + +### Stage 1 — Parity for object-only scenes (no robots) +- [ ] Implement `SpecBuilder.build(world)` covering everything `MJCFBuilder.build_objects_only` does: visual, asset, lights, ground, cameras, objects. +- [ ] Add feature flag `STRANDS_SIM_USE_MJSPEC=1` in `simulation.py:_recompile_world()` that routes to `SpecBuilder.build(self._world).compile()` vs. the old string path. +- [ ] Ensure hatch env tests pass under *both* code paths. +- [ ] Add a spec-focused test that asserts on spec structure (e.g. 
`spec.body("cube_1").pos == [...]`), not XML strings. + +### Stage 2 — Camera orientation +- [ ] In `SpecBuilder._add_camera`, use `cam.targetbody` when a target is given and a named body at that location exists; otherwise set `cam.quat` from a helper that *uses MuJoCo's own math*. +- [ ] Delete `_camera_xyaxes_from_target` from `mjcf_builder.py` once unused. +- [ ] Port tests in `tests/simulation/test_mujoco_cameras.py` (if they exist — verify). + +### Stage 3 — Single-robot attach +- [ ] `SpecBuilder._attach_robot(spec, robot)` using `spec.attach(robot_spec, prefix=..., frame=...)`. +- [ ] Verify joints, actuators, sensors discovered via existing `_discover_*` helpers in `simulation.py` still work (they read from `model`, which is identical downstream). +- [ ] Remove `_save_and_patch_xml` dependency for single-robot scenes. + +### Stage 4 — Multi-robot compose +- [ ] Replace `compose_multi_robot_scene` with `SpecBuilder.build(world)` + per-robot `attach()`. +- [ ] Delete `_prefix_robot_names`, `_namespace_robot_default_classes`, `_collect_existing_class_names` once all consumers migrated. +- [ ] Confirm namespace conventions (`{robot_name}_` prefix) match what downstream code reads (grep for joint/actuator name assumptions). + +### Stage 5 — Live inject/eject via spec mutation +- [ ] Port `inject_object_into_scene` to `spec.worldbody.add_body(...)` + `spec.recompile(model, data)`. +- [ ] Port `eject_body_from_scene` to `spec.body(name).delete()` + recompile. +- [ ] Port `inject_camera_into_scene`, `eject_robot_from_scene` similarly. +- [ ] Delete `_patch_xml_paths`, `_rewrite_mesh_paths`, `_get_abs_meshdir`, `_save_and_patch_xml` once unused. +- [ ] Handle the `keyframe` qpos-mismatch issue: MjSpec has `spec.keys` — clear or resize appropriately on recompile. + +### Stage 6 — Agent-authored raw MJCF +- [ ] Add `replace_scene_mjcf(world, xml)` and `patch_scene_mjcf(world, ops)` in `scene_ops.py`. 
+- [ ] Expose as Strands `@tool` decorators in `tool_spec.json` + a new tool module. Document clearly: "escape hatch, validated by compilation." +- [ ] Integration test: agent writes a scene with a `<tendon>` element (something `SimObject` can't express), confirms it compiles and simulates. + +### Stage 7 — Cleanup +- [ ] Remove feature flag once all stages are green in CI. +- [ ] Delete `mjcf_builder.py`. +- [ ] Audit `scene_ops.py` — should shrink from ~980 lines to ~400. +- [ ] Update `AGENTS.md` if the scene-building conventions changed. + +--- + +## Risks & mitigations + +1. **`recompile(model, data)` preserves qpos only when joint dims unchanged.** + *Mitigation:* Adding a freejoint changes `nq` (the length of `qpos`). Current code deletes keyframes. Spec version: after recompile, re-inject qpos for unchanged joints by name, leave new joints at their default (body `pos`/`quat`). + +2. **`spec.to_xml()` is canonical, not byte-identical to input.** + *Mitigation:* Any test asserting exact XML strings is wrong and should be rewritten against spec structure or compiled model properties. Grep for `assert.*xml` in tests. + +3. **`attach()` default-class naming differs from current `_namespace_robot_default_classes`.** + *Mitigation:* Concrete difference: `attach(prefix="r1_")` creates `r1_main` default class. Current code may use a different pattern. Find all places that read default-class names (likely none in Python — defaults are consumed by MuJoCo's compiler) and verify. Add integration test with 2 robots from different URDFs to catch regressions. + +4. **MuJoCo compiler errors are C-level and sometimes cryptic.** + *Mitigation:* Wrap `spec.compile()` in `scene_ops.replace_scene_mjcf` with `try/except ValueError` and add context: which spec, which body, what op was being applied. + +5. **PR #85 is actively modifying `scene_ops.py`.** + *Mitigation:* Coordinate with @yinsong1986 before Stage 5. Stages 0–4 are mostly additive and land-safe. + +6.
**MjSpec API churn 3.2 → 3.8.** + *Mitigation:* The surface we need (`from_string`, `from_file`, `add_body`, `add_geom`, `attach`, `compile`, `recompile`, `to_xml`, `body.delete`) has been stable since 3.2. Pin `>=3.2.0,<4.0.0` to be safe. + +--- + +## Non-goals + +- Not rewriting `physics.py`, `rendering.py`, `randomization.py`, `recording.py` — those consume `MjModel`/`MjData`, which stay unchanged. +- Not changing the Strands tool surface for existing operations — `add_object`, `add_robot`, etc. keep their signatures. +- Not changing `SimWorld` / `SimObject` / `SimRobot` / `SimCamera` public fields. +- Not touching Isaac Sim / PyBullet backends (they don't exist yet, but the `SimEngine` ABC is unaffected). + +--- + +## Success criteria + +- `mjcf_builder.py` deleted. +- `scene_ops.py` under 500 lines. +- All existing unit + integration tests pass. +- One new integration test proves an agent can author raw MJCF including an element not expressible via `SimObject` (e.g. `<tendon>` or `<equality>`). +- No test asserts on exact XML strings. +- `grep -r "f'<" strands_robots/simulation/mujoco/` returns nothing. + +--- + +## Appendix: proof-of-life snippet + +Verified on this host (`mujoco==3.8.0`, 2026-05-05): + +```python +import mujoco + +# Parse → edit → recompile → serialize +spec = mujoco.MjSpec.from_string('<mujoco><worldbody/></mujoco>') +alice = spec.worldbody.add_body(name='alice', pos=[0, 0, 1]) +alice.add_freejoint() +alice.add_geom(name='alice_geom', + type=mujoco.mjtGeom.mjGEOM_SPHERE, + size=[0.1, 0, 0], rgba=[1, 0, 0, 1]) +model = spec.compile() # validates; raises on error +assert model.nbody == 2 + +# Attach a second spec (composition — replaces ~200 lines) +robot = mujoco.MjSpec.from_string( + '<mujoco><worldbody><body name="arm">' + '<geom name="link" type="capsule" size="0.02 0.1"/>' + '</body></worldbody></mujoco>') +frame = spec.worldbody.add_frame(pos=[1, 0, 0]) +spec.attach(robot, prefix='r1_', frame=frame) +# Emits: body name="r1_arm", geom name="r1_link", plus a "r1_main" default class.
+ +print(spec.to_xml()) # canonical round-trip +``` + +--- + +## Handoff for autonomous agent + +When executing this plan: + +1. **Work on a feature branch** off `main` — `feat/mjspec-refactor`. +2. **Stage by stage, one PR per stage.** Each PR must be green in CI and reviewable in isolation. +3. **Keep the feature flag alive** until Stage 7. Both code paths tested in CI. +4. **Track progress on the project board** per `AGENTS.md` rule ("the board is the source of truth"): https://github.com/orgs/strands-labs/projects/2 +5. **Do not** touch URDF parsing, policy providers, teleoperation, or calibration. Stay inside `simulation/mujoco/`. +6. **Do not** delete anything in `scene_ops.py` until every downstream caller is migrated — audit with `grep -r` before each deletion. +7. **Ask before** bumping any dependency bound other than `mujoco`. diff --git a/README.md b/README.md index 0a93a93..7e4591d 100644 --- a/README.md +++ b/README.md @@ -493,7 +493,7 @@ agent.tool.gr00t_inference(action="stop", port=8000) | Variable | Description | Default | |----------|-------------|---------| | `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (MJCF, meshes) | `~/.strands_robots/assets/` | -| `GROOT_API_TOKEN` | API token for GR00T inference service | — | +| `GROOT_API_TOKEN` | API token for GR00T inference service | - | ### Cache Directory @@ -511,6 +511,108 @@ To clear the cache: `rm -rf ~/.strands_robots/assets/` To change the cache location: `export STRANDS_ASSETS_DIR=/path/to/custom/dir` +## Simulation (MuJoCo) + +`strands-robots` ships a MuJoCo-backed simulation AgentTool - 58 actions +exposed to any Strands agent for world composition, physics, policy +execution, and video/dataset recording. 
+ +### Install + +```bash +pip install "strands-robots[sim-mujoco]" +# For LeRobotDataset recording (parquet + training data): +pip install "strands-robots[sim-mujoco,lerobot]" +``` + +### Quick start + +```python +from strands_robots.simulation import Simulation + +sim = Simulation(tool_name="sim", mesh=False) +sim.create_world() +sim.add_robot(name="arm", data_config="so100") +sim.add_object(name="cube", shape="box", position=[0.3, 0, 0.05]) +sim.add_camera(name="topdown", position=[0, 0, 1.5], target=[0, 0, 0]) + +sim.run_policy(robot_name="arm", policy_provider="mock", n_steps=200, + control_frequency=50.0, fast_mode=True) + +frame = sim.render(camera_name="topdown") # returns {status, content:[text, image]} +``` + +### 58 actions grouped + +- **World & objects**: `create_world`, `load_scene`, `add_robot`, + `add_object`, `move_object`, `list_objects`, `list_robots`, + `remove_robot`, `remove_object`, `destroy`, `reset`, `get_state`, + `save_state`, `load_state`, `list_checkpoints`. +- **Physics**: `step`, `set_timestep`, `set_gravity`, `apply_force`, + `raycast`, `multi_raycast`, `set_body_properties`, + `set_geom_properties`, `get_body_state`, `get_joint_state`, + `set_joint_positions`, `set_joint_velocities`, `forward_kinematics`, + `get_mass_matrix`, `inverse_dynamics`, `get_total_mass`, + `get_jacobian`, `get_energy`, `get_contacts`, `get_sensor_data`. +- **Cameras & rendering**: `add_camera`, `remove_camera`, `render`, + `render_depth`, `render_all`, `start_cameras_recording`, + `stop_cameras_recording`, `get_cameras_recording_status`. +- **Policy**: `start_policy`, `run_policy`, `stop_policy`, + `replay_episode`, `eval_policy`. +- **Randomization**: `randomize`. +- **Recording (LeRobotDataset)**: `start_recording`, `stop_recording`, + `get_recording_status`. +- **Introspection & util**: `get_features`, `list_urdfs`, `register_urdf`, + `export_xml`, `open_viewer`, `close_viewer`. 
+ +### Common footguns + +- **Planes must be static.** `add_object(shape="plane")` auto-sets + `is_static=True`. Passing `is_static=False` on a plane is a hard error + (MuJoCo planes are infinite and can't have dynamic mass). +- **Camera orientation.** Pass `target=[x,y,z]` to look at a point - + without it the camera faces forward by default. `target == position` + errors. +- **MP4 vs dataset recording.** `start_cameras_recording` writes plain + MP4 per-camera and runs under `[sim-mujoco]` alone. `start_recording` + writes a LeRobotDataset (parquet + MP4 + schema) and requires the + `[lerobot]` extra. +- **Policy running → mutations blocked.** While a policy runs on any + robot, state-mutating actions (`reset`, `set_gravity`, joint setters, + `apply_force`, `set_body_properties`, `set_geom_properties`, + `load_state`, `randomize`, `move_object`) error with *"Cannot 'X' + while a policy is running."* Stop it first with + `stop_policy(robot_name='...')`. +- **Horizon parameters.** `run_policy` accepts either `duration` + + `control_frequency` (real-time) OR `n_steps` + `control_frequency` + (step-count). Pass `fast_mode=True` to skip the between-step sleep + during batch eval / data collection. +- **Name collisions.** Objects, bodies, robots, and cameras share the + MuJoCo name table. Robot joints and actuators are auto-namespaced as + `{robot_name}/{joint}` in multi-robot scenes. Object geoms are + injected as `{object_name}_geom`; `set_geom_properties` accepts the + bare object name as an alias. +- **Oversized render**: MuJoCo's offscreen framebuffer is capped by + `<visual><global offwidth="..." offheight="..."/></visual>` in MJCF. Requesting a bigger + render now errors with a plain message naming the cap - either lower + the request or rebuild the model with larger dims. + +### Self-healing features + +- Unknown parameters are rejected with *"Unknown parameter X for action + Y. Valid: [...]"* so the LLM learns the correct name without trial- + and-error.
+- Missing required parameters produce *"Action X requires parameter Y."* + (no Python `TypeError` leaks). +- Vector dimensions and numeric dtype are validated before MuJoCo sees + them (previously zero-length direction vectors crashed the Python + process via `mj_ray` C-level abort). +- `destroy()` and `cleanup()` empty the renderer TLS cache and shut down + the executor - no RSS growth across repeated create/destroy cycles. + +For the full action contract and test coverage see +`tests/simulation/mujoco/test_agenttool_contract.py`. + ## Contributing We welcome contributions! Please see: diff --git a/pyproject.toml b/pyproject.toml index f1a7090..2c2e0fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,13 +48,16 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] -sim = [ +sim-mujoco = [ "robot_descriptions>=1.11.0,<2.0.0", + "mujoco>=3.2.0,<4.0.0", + "imageio>=2.28.0,<3.0.0", + "imageio-ffmpeg>=0.4.0,<1.0.0", ] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", - "strands-robots[sim]", + "strands-robots[sim-mujoco]", ] dev = [ "pytest>=6.0,<9.0.0", @@ -79,6 +82,7 @@ packages = ["strands_robots"] [tool.hatch.envs.default] installer = "uv" +features = ["all"] dependencies = [ "pytest>=6.0,<9.0.0", "pytest-cov>=4.0.0,<6.0.0", @@ -128,7 +132,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check @@ -161,6 +165,22 @@ module = ["strands_robots.registry.*"] warn_return_any = false 
disallow_untyped_defs = false +# MuJoCo simulation — mixins use cooperative self._world patterns +# attr-defined: Mixins access self._world/self._lock/etc. from Simulation (cooperative pattern) +# assignment: PEP 484 implicit Optional (= None on typed params) +# override: Subclass signatures extend base with extra params (orientation, mesh_path) +# misc: Multiple inheritance method resolution conflicts between mixin + ABC +[[tool.mypy.overrides]] +module = ["strands_robots.simulation.mujoco.*"] +disallow_untyped_defs = false +warn_return_any = false + +# Async utils and dataset recorder — thin wrappers with dynamic types +[[tool.mypy.overrides]] +module = ["strands_robots._async_utils", "strands_robots.dataset_recorder"] +disallow_untyped_defs = false +warn_return_any = false + # Test files — relaxed type checking for mocks, fixtures, and test utilities [[tool.mypy.overrides]] module = ["tests.*", "tests_integ.*"] diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index 8ee9c41..7f5a54f 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -25,14 +25,11 @@ import warnings as _warnings from typing import Any -# ------------------------------------------------------------------ -# Light-weight imports — no torch / lerobot dependency -# ------------------------------------------------------------------ +# Light-weight imports - no torch / lerobot dependency from strands_robots.policies import MockPolicy, Policy, create_policy # noqa: F401 -# ------------------------------------------------------------------ # Lazy-loaded heavy symbols -# ------------------------------------------------------------------ + # Maps public name -> (module_path, attribute_name) _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "Robot": ("strands_robots.robot", "Robot"), diff --git a/strands_robots/_async_utils.py b/strands_robots/_async_utils.py new file mode 100644 index 0000000..478518b --- /dev/null +++ b/strands_robots/_async_utils.py @@ -0,0 +1,31 
"""Async-to-sync helper for resolving coroutines in sync contexts."""

import asyncio
import concurrent.futures

# Module-level executor reused across calls to avoid creating threads at high frequency.
# A single worker is sufficient - we only need to offload one asyncio.run() at a time.
# NOTE(review): because there is exactly one worker, a coroutine that is already
# running on this executor and synchronously calls _resolve_coroutine() again
# would wait on itself - confirm no caller nests like that.
_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="strands_async")


def _resolve_coroutine(coro_or_result):  # type: ignore[no-untyped-def]
    """Safely resolve a potentially-async result to a sync value.

    Handles three cases:
    1. Already a plain value → return as-is
    2. Coroutine, no running loop → asyncio.run()
    3. Coroutine, inside running loop → offload to reused thread

    Args:
        coro_or_result: Either a coroutine or an already-resolved value.

    Returns:
        The resolved (sync) value.

    Raises:
        Exception: Whatever the coroutine itself raises is propagated unchanged.
    """
    if not asyncio.iscoroutine(coro_or_result):
        return coro_or_result

    # Only the loop probe is guarded. If the coroutine itself raises RuntimeError,
    # that error must reach the caller rather than being mistaken for
    # "no running loop" (which would re-run an already-consumed coroutine).
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: run the coroutine right here.
        return asyncio.run(coro_or_result)

    # A loop is already running in this thread, so asyncio.run() is not allowed
    # here; run the coroutine on the helper thread and block for its result.
    return _EXECUTOR.submit(asyncio.run, coro_or_result).result()
``robot_descriptions`` package - recommended by MuJoCo Menagerie. 2. Shallow ``git clone`` fallback for Menagerie robots. 3. Custom GitHub repos for non-Menagerie robots. @@ -40,7 +40,7 @@ _ALLOWED_CLONE_URL_RE = re.compile(r"^https://github\.com/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+\.git$") -# ── robot_descriptions integration ──────────────────────────────────── +# robot_descriptions integration def _robot_descriptions_available() -> bool: @@ -102,10 +102,7 @@ def _resolve_robot_descriptions_module(name: str, info: dict) -> str | None: return None -# ── Helpers ─────────────────────────────────────────────────────────── - - -#: Alias for backward compatibility — use :func:`strands_robots.utils.get_assets_dir`. +#: Alias for backward compatibility - use :func:`strands_robots.utils.get_assets_dir`. get_user_assets_dir = get_assets_dir @@ -151,7 +148,7 @@ def _get_source(info: dict[str, Any] | None) -> dict[str, Any]: def _shallow_clone(repo_url: str, dest: str, *, timeout: int = 120) -> None: """Shallow-clone *repo_url* into *dest*. - Only HTTPS ``github.com`` URLs are accepted — ``ssh://``, ``git://``, + Only HTTPS ``github.com`` URLs are accepted - ``ssh://``, ``git://``, ``file://``, and other schemes are rejected to prevent command-injection and SSRF risks. @@ -174,7 +171,7 @@ def _shallow_clone(repo_url: str, dest: str, *, timeout: int = 120) -> None: # Filenames/patterns that are safe to strip from an upstream source tree before # we copy it into the user's asset cache. Filtering at *copy* time (rather than # deleting afterwards) means we never touch files that may already exist in *dst* -# — which matters when the user keeps notes/README alongside assets. +# - which matters when the user keeps notes/README alongside assets. 
_COPY_CLEAN_SKIP = frozenset({"README.md", "LICENSE", "CHANGELOG.md"}) _COPY_CLEAN_SUFFIX = (".png", ".jpg", ".jpeg") @@ -195,9 +192,6 @@ def _ignore(_dir: str, names: list[str]) -> list[str]: shutil.copytree(str(src), str(dst), dirs_exist_ok=True, ignore=_ignore) -# ── Download backends ───────────────────────────────────────────────── - - def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> dict[str, str]: """Download robots using the ``robot_descriptions`` package. @@ -234,9 +228,9 @@ def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> if expected_xml.exists(): results[name] = "downloaded" continue - # Stale symlink — remove and re-download via git + # Stale symlink - remove and re-download via git dst.unlink() - results[name] = f"failed: stale symlink — {info['asset']['model_xml']} not found in {package_path}" + results[name] = f"failed: stale symlink - {info['asset']['model_xml']} not found in {package_path}" continue if dst.exists() or dst.is_symlink(): dst.unlink() if dst.is_symlink() else shutil.rmtree(str(dst)) @@ -251,7 +245,7 @@ def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> if not expected_xml.exists(): logger.warning( "robot_descriptions module '%s' linked for %s but " - "expected XML '%s' not found — falling back to git", + "expected XML '%s' not found - falling back to git", module_name, name, info["asset"]["model_xml"], @@ -261,7 +255,7 @@ def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> else: shutil.rmtree(str(dst), ignore_errors=True) results[name] = ( - f"failed: XML mismatch — module '{module_name}' does not contain {info['asset']['model_xml']}" + f"failed: XML mismatch - module '{module_name}' does not contain {info['asset']['model_xml']}" ) continue @@ -333,7 +327,7 @@ def _download_from_github(name: str, info: dict, dest_dir: Path) -> str: return f"failed: {exc}" -# ── Orchestrator 
────────────────────────────────────────────────────── +# Orchestrator def auto_download_robot(name: str, info: dict[str, Any]) -> bool: @@ -379,20 +373,20 @@ def download_robots( """Download robot model assets from their respective sources. Strategy (in order of preference): - 1. ``robot_descriptions`` package — recommended by MuJoCo Menagerie. + 1. ``robot_descriptions`` package - recommended by MuJoCo Menagerie. 2. Shallow ``git clone`` fallback for Menagerie robots. 3. Custom GitHub repos for non-Menagerie robots. Args: names: Robot names to download (``None`` = all sim robots). - category: Filter by category (arm, humanoid, mobile, …). + category: Filter by category (arm, humanoid, mobile, ...). force: Re-download even if present. Returns: Dict with downloaded/skipped/failed counts, names, and details. """ dest_dir = get_user_assets_dir() - # Filter None values — get_robot() can return None for unknown names + # Filter None values - get_robot() can return None for unknown names all_sim: dict[str, dict[str, Any]] = { r["name"]: info for r in registry_list_robots(mode="sim") if (info := get_robot(r["name"])) is not None } diff --git a/strands_robots/assets/manager.py b/strands_robots/assets/manager.py index ca610a6..34f37ce 100644 --- a/strands_robots/assets/manager.py +++ b/strands_robots/assets/manager.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -# Module-level conditional import — keeps manager.py importable in +# Module-level conditional import - keeps manager.py importable in # environments where the optional ``robot_descriptions`` package (and its # transitive heavyweight deps like ``GitPython``) are not installed. # When ``download`` is not available, auto-download simply returns False. 
@@ -32,9 +32,9 @@ _auto_download_robot_impl = None # type: ignore[assignment] -# ───────────────────────────────────────────────────────────────────── +# # Model path resolution (delegates to registry) -# ───────────────────────────────────────────────────────────────────── +# def _auto_download_robot(name: str, info: dict) -> bool: @@ -114,7 +114,7 @@ def _resolve_candidates(asset_dir_name: str, xml_file: str, name: str) -> list[P def is_robot_asset_present(name: str) -> bool: """Check whether a robot's model XML exists on disk without triggering downloads. - Pure filesystem check — no auto-download, no mesh walk, no network. + Pure filesystem check - no auto-download, no mesh walk, no network. Use this for status queries (e.g. ``download_assets(action="status")``) where you need to quickly check presence without side effects. @@ -194,7 +194,7 @@ def resolve_model_path( # Check user-registered asset path first (highest priority). # ``xml_file`` comes from user_robots.json, so we still gate it through # :func:`safe_join` to block path traversal even for user-authored entries - # (defense in depth — protects against a compromised user_robots.json and + # (defense in depth - protects against a compromised user_robots.json and # keeps the trust boundary identical to the built-in registry path). 
user_path = info.get("_user_asset_path") if user_path: @@ -214,7 +214,7 @@ def resolve_model_path( candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) if not candidates: - # No XML found at all — try auto-download, then re-search + # No XML found at all - try auto-download, then re-search logger.info("No XML found for %s, attempting auto-download...", name) if _auto_download_robot(name, info): candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) @@ -230,7 +230,7 @@ def resolve_model_path( logger.debug("Resolved %s → %s (has meshes)", name, path) return Path(path) - # XML found but no meshes — auto-download and re-check + # XML found but no meshes - auto-download and re-check logger.info("XML found for %s but no meshes, attempting auto-download...", name) if _auto_download_robot(name, info): # Re-scan after download (new symlinks may have appeared) @@ -305,7 +305,7 @@ def list_available_robots() -> list[dict]: name = r["name"] present = is_robot_asset_present(name) info = get_robot(name) or {} - # Only resolve full path when asset is present — avoids download attempts + # Only resolve full path when asset is present - avoids download attempts path = resolve_model_path(name) if present else None robots.append( { diff --git a/strands_robots/dataset_recorder.py b/strands_robots/dataset_recorder.py new file mode 100644 index 0000000..d38f01b --- /dev/null +++ b/strands_robots/dataset_recorder.py @@ -0,0 +1,515 @@ +"""LeRobotDataset recorder bridge for strands-robots. + +Wraps LeRobotDataset so that both robot.py (real hardware) and +simulation.py (MuJoCo) can produce training-ready datasets with +a single add_frame() call per control step. 
+ +Usage: + recorder = DatasetRecorder.create( + repo_id="user/my_dataset", + fps=30, + robot_features=robot.observation_features, + action_features=robot.action_features, + task="pick up the red cube", + ) + # In control loop: + recorder.add_frame(observation, action, task="pick up the red cube") + # End of episode: + recorder.save_episode() + # Optionally: + recorder.push_to_hub() +""" + +import functools +import logging +import sys +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + +# Lazy check for LeRobot availability +# We must NOT import lerobot at module level because it pulls in +# `datasets` → `pandas`, which can crash with a numpy ABI mismatch on +# systems where the system pandas was compiled against an older numpy +# (e.g. JetPack / Jetson with system pandas 2.1.4 + pip numpy 2.x). + + +@functools.lru_cache(maxsize=1) +def has_lerobot_dataset() -> bool: + """Check if lerobot is available. Result is cached after first call.""" + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: F401 + + return True + except (ImportError, ValueError, RuntimeError) as exc: + logger.debug("lerobot not available: %s", exc) + return False + + +def _get_lerobot_dataset_class(): + """Import and return LeRobotDataset class, or raise ImportError. + + Supports test mocking: if ``strands_robots.dataset_recorder.LeRobotDataset`` + has been set (by a test mock), returns that class directly. + """ + # Support test mocking: check module-level overrides + this_module = sys.modules[__name__] + + # If a test injected a mock LeRobotDataset class, use it + mock_cls = getattr(this_module, "LeRobotDataset", None) + if mock_cls is not None: + return mock_cls + + # Actual import + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + return LeRobotDataset + except (ImportError, ValueError, RuntimeError) as exc: + raise ImportError( + f"lerobot not available ({exc}). 
Install with: pip install lerobot\nRequired for LeRobotDataset recording." + ) from exc + + +class DatasetRecorder: + """Bridge between strands-robots control loops and LeRobotDataset. + + Handles the full lifecycle: + 1. create() - build LeRobotDataset with correct features + 2. add_frame() - called every control step with obs + action + 3. save_episode() - finalize episode (encodes video, writes parquet) + 4. push_to_hub() - upload to HuggingFace + + Works for both real hardware (robot.py) and simulation (simulation.py). + """ + + def __init__(self, dataset, task: str = "", strict: bool = True): + self.dataset = dataset + self.default_task = task + self.frame_count = 0 + self.dropped_frame_count = 0 + self.strict = strict + self.episode_count = 0 + self._closed = False + self._cached_state_keys: list[str] | None = None + self._cached_action_keys: list[str] | None = None + + @classmethod + def create( + cls, + repo_id: str, + fps: int = 30, + robot_type: str = "unknown", + robot_features: dict[str, Any] | None = None, + action_features: dict[str, Any] | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + task: str = "", + root: str | None = None, + use_videos: bool = True, + vcodec: str = "libsvtav1", + streaming_encoding: bool = True, + image_writer_threads: int = 4, + video_backend: str = "auto", + ) -> "DatasetRecorder": + """Create a new DatasetRecorder with auto-detected features. + + Args: + repo_id: HuggingFace dataset ID (e.g. "user/my_dataset") + fps: Recording frame rate + robot_type: Robot type string (e.g. 
"so100", "panda") + robot_features: Dict of observation feature names → types + (from robot.observation_features or sim joint names) + action_features: Dict of action feature names → types + camera_keys: List of camera names (images become video features) + joint_names: List of joint names (alternative to robot_features for sim) + task: Default task description + root: Local directory for dataset storage + use_videos: Encode camera frames as video (True) or keep as images + vcodec: Video codec (h264, hevc, libsvtav1) + streaming_encoding: Stream-encode video during capture + image_writer_threads: Threads for writing image frames + video_backend: Video backend for encoding ("auto" for HW encoder auto-detect) + """ + # Lazy import - this is where we actually need lerobot + LeRobotDatasetCls = _get_lerobot_dataset_class() + + # Build features dict in LeRobot format + features = cls._build_features( + robot_features=robot_features, + action_features=action_features, + camera_keys=camera_keys, + joint_names=joint_names, + use_videos=use_videos, + ) + + logger.info(f"Creating LeRobotDataset: {repo_id} @ {fps}fps, {len(features)} features, robot_type={robot_type}") + + # Build kwargs, skip unsupported params for this LeRobot version + create_kwargs = dict( + repo_id=repo_id, + fps=fps, + root=root, + robot_type=robot_type, + features=features, + use_videos=use_videos, + image_writer_threads=image_writer_threads, + vcodec=vcodec, + ) + # streaming_encoding only in newer LeRobot versions + import inspect + + create_sig = inspect.signature(LeRobotDatasetCls.create) + if "streaming_encoding" in create_sig.parameters: + create_kwargs["streaming_encoding"] = streaming_encoding + if "video_backend" in create_sig.parameters: + create_kwargs["video_backend"] = video_backend + dataset = LeRobotDatasetCls.create(**create_kwargs) + + recorder = cls(dataset=dataset, task=task) + logger.info("DatasetRecorder ready: %s", repo_id) + return recorder + + @classmethod + def _build_features( 
+ cls, + robot_features: dict | None = None, + action_features: dict | None = None, + camera_keys: list[str] | None = None, + joint_names: list[str] | None = None, + use_videos: bool = True, + ) -> dict[str, Any]: + """Build LeRobot v3-compatible features dict. + + LeRobot v3 features format: + { + "observation.images.camera_name": {"dtype": "video", "shape": (C, H, W), "names": [...]}, + "observation.state": {"dtype": "float32", "shape": (N,), "names": [...]}, + "action": {"dtype": "float32", "shape": (N,), "names": [...]}, + } + + Note: "names" must be a flat list of strings, NOT a dict like {"motors": [...]}. + """ + features = {} + + # Observation: cameras → video/image features + if camera_keys: + for cam_name in camera_keys: + key = f"observation.images.{cam_name}" + dtype = "video" if use_videos else "image" + features[key] = { + "dtype": dtype, + "shape": ( + 3, + 480, + 640, + ), # CHW default, actual shape set on first frame + "names": ["channels", "height", "width"], + } + + # Observation: state (joint positions) + state_dim = 0 + state_names = [] + if robot_features: + # Count scalar features (exclude cameras) + state_keys = [ + k + for k, v in robot_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + state_dim = len(state_keys) + state_names = state_keys + elif joint_names: + state_dim = len(joint_names) + state_names = list(joint_names) + + if state_dim > 0: + features["observation.state"] = { + "dtype": "float32", + "shape": (state_dim,), + "names": state_names, + } + + # Action + action_dim = 0 + action_names = [] + if action_features: + action_keys = [ + k + for k, v in action_features.items() + if not isinstance(v, dict) or v.get("dtype") not in ("image", "video") + ] + action_dim = len(action_keys) + action_names = action_keys + elif joint_names: + action_dim = len(joint_names) + action_names = list(joint_names) + elif state_dim > 0: + action_dim = state_dim # Same dim as state by default + action_names 
= state_names[:] + + if action_dim > 0: + features["action"] = { + "dtype": "float32", + "shape": (action_dim,), + "names": action_names[:action_dim], + } + + return features + + def add_frame( + self, + observation: dict[str, Any], + action: dict[str, Any], + task: str | None = None, + camera_keys: list[str] | None = None, + ) -> None: + """Add a single control-loop frame to the dataset. + + This is the key method - called every step in the control loop. + + Args: + observation: Raw observation dict from robot/sim + (joint_name → float, camera_name → np.ndarray) + action: Action dict (joint_name → float) + task: Task description (uses default if None) + camera_keys: Which keys in observation are camera images + """ + if self._closed: + return + + frame = {} + + # Detect camera vs state keys + if camera_keys is None: + camera_keys = [k for k, v in observation.items() if isinstance(v, np.ndarray) and v.ndim >= 2] + + state_keys = [k for k in observation.keys() if k not in camera_keys] + + # Camera images → observation.images.{name} + for cam_key in camera_keys: + img = observation[cam_key] + if isinstance(img, np.ndarray): + # LeRobot expects HWC uint8 for add_frame + if img.dtype != np.uint8: + img = (np.clip(img, 0, 1) * 255).astype(np.uint8) + frame[f"observation.images.{cam_key}"] = img + + # State → observation.state (flattened vector) + # Use feature schema ordering to match the dataset schema declared in _build_features(). 
+ if state_keys: + state_vals = [] + if self._cached_state_keys is None: + feat = self.dataset.features.get("observation.state", {}) + state_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_state_keys = state_names if state_names else sorted(state_keys) + + for k in self._cached_state_keys: + v = observation.get(k) + if v is None: + state_vals.append(0.0) + elif isinstance(v, (int, float)): + state_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + state_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + state_vals.extend(arr.tolist()) + if state_vals: + frame["observation.state"] = np.array(state_vals, dtype=np.float32) + + # Action → flattened vector + # Use feature schema ordering for actions too. + if action: + action_vals = [] + if self._cached_action_keys is None: + feat = self.dataset.features.get("action", {}) + action_names = feat.get("names", []) if isinstance(feat, dict) else getattr(feat, "names", []) + self._cached_action_keys = action_names if action_names else sorted(action.keys()) + + for k in self._cached_action_keys: + v = action.get(k) + if v is None: + action_vals.append(0.0) + elif isinstance(v, (int, float)): + action_vals.append(float(v)) + elif isinstance(v, np.ndarray) and v.ndim == 0: + action_vals.append(float(v)) + elif isinstance(v, (list, np.ndarray)): + arr = np.asarray(v, dtype=np.float32).flatten() + action_vals.extend(arr.tolist()) + if action_vals: + frame["action"] = np.array(action_vals, dtype=np.float32) + + # Task (mandatory for LeRobot v3) + frame["task"] = task or self.default_task or "untitled" + + # Reconcile camera keys between frame and feature schema + # Normalize namespaced camera keys (e.g. "arm0/wrist_cam" → "arm0__wrist_cam") + # to match the schema declared in _build_features. 
MuJoCo uses "/" as a + # namespace separator for multi-robot cameras, but LeRobot feature names + # cannot contain "/" (reserved for nested-feature addressing). + declared_cam_keys = {k for k in self.dataset.features if k.startswith("observation.images.")} + frame_cam_keys = {k for k in list(frame.keys()) if k.startswith("observation.images.")} + for cam_key in frame_cam_keys: + normalized = cam_key.replace("/", "__") + if normalized != cam_key and normalized in declared_cam_keys: + frame[normalized] = frame.pop(cam_key) + + # Strip undeclared cameras (keys present in obs but not registered in + # _build_features). This avoids LeRobot's "Extra features" error. + # Declared-but-missing cameras (e.g. when a render fails) are left alone - + # LeRobot tolerates absent columns and the episode simply won't have that + # camera's data. + frame_cam_keys_final = {k for k in frame if k.startswith("observation.images.")} + for extra in frame_cam_keys_final - declared_cam_keys: + del frame[extra] + + # Add to dataset + try: + self.dataset.add_frame(frame) + self.frame_count += 1 + except Exception as e: + if self.strict: + raise # Fail-fast per AGENTS.md convention #5 + self.dropped_frame_count += 1 + n = self.dropped_frame_count + # Log at 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, then every 1000 + if (n & (n - 1)) == 0 or n % 1000 == 0: + logger.warning( + "add_frame failed (frame %d, dropped %d): %s", + self.frame_count, + self.dropped_frame_count, + e, + ) + + def save_episode(self) -> dict[str, Any]: + """Finalize current episode - writes parquet, encodes video, computes stats. + + LeRobot v3: save_episode() takes no task argument. Tasks are stored + per-frame in the episode buffer via add_frame(). 
+ + Returns: + Dict with episode info + """ + if self._closed: + return {"status": "error", "message": "Recorder closed"} + + try: + self.dataset.save_episode() + self.episode_count += 1 + ep_frames = self.frame_count # Total frames so far + logger.info(f"Episode {self.episode_count} saved: {ep_frames} total frames") + return { + "status": "success", + "episode": self.episode_count, + "total_frames": ep_frames, + } + except Exception as e: + logger.error("save_episode failed: %s", e) + return {"status": "error", "message": str(e)} + + def finalize(self) -> None: + """Finalize the dataset (close parquet writers, flush metadata).""" + if self._closed: + return + try: + self.dataset.finalize() + except Exception as e: + logger.warning("finalize warning: %s", e) + self._closed = True + + def push_to_hub( + self, + tags: list[str] | None = None, + private: bool = False, + ) -> dict[str, Any]: + """Push dataset to HuggingFace Hub. + + Args: + tags: Optional tags for the dataset + private: Upload as private dataset + + Returns: + Dict with push status + """ + try: + self.dataset.push_to_hub(tags=tags, private=private) + logger.info("Dataset pushed to hub: %s", self.dataset.repo_id) + return { + "status": "success", + "repo_id": self.dataset.repo_id, + "episodes": self.episode_count, + "frames": self.frame_count, + } + except Exception as e: + logger.error("push_to_hub failed: %s", e) + return {"status": "error", "message": str(e)} + + @property + def repo_id(self) -> str: + return self.dataset.repo_id + + @property + def root(self) -> str: + return str(self.dataset.root) + + def __repr__(self) -> str: + return f"DatasetRecorder(repo_id={self.repo_id}, episodes={self.episode_count}, frames={self.frame_count})" + + +# Shared replay-episode helpers + + +def load_lerobot_episode(repo_id: str, episode: int = 0, root: str | None = None): + """Load a LeRobotDataset and resolve the frame range for an episode. 
+ + Returns: + Tuple of (dataset, episode_start, episode_length) on success. + + Raises: + ImportError: If lerobot is not installed. + ValueError: If the episode is out of range or has no frames. + """ + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + ds = LeRobotDataset(repo_id=repo_id, root=root) + + num_episodes = ds.meta.total_episodes if hasattr(ds.meta, "total_episodes") else len(ds.meta.episodes) + if episode >= num_episodes: + raise ValueError(f"Episode {episode} out of range (0-{num_episodes - 1})") + + episode_start = 0 + episode_length = 0 + try: + if hasattr(ds, "episode_data_index"): + from_idx = ds.episode_data_index["from"][episode].item() + to_idx = ds.episode_data_index["to"][episode].item() + episode_start = from_idx + episode_length = to_idx - from_idx + else: + for i in range(episode): + ep_info = ds.meta.episodes[i] if hasattr(ds.meta, "episodes") else {} + episode_start += ep_info.get("length", 0) + ep_info = ds.meta.episodes[episode] if hasattr(ds.meta, "episodes") else {} + episode_length = ep_info.get("length", 0) + except Exception: + # Last resort: scan frames to find episode boundaries + for idx in range(len(ds)): + frame = ds[idx] + frame_ep = frame.get("episode_index", -1) if hasattr(frame, "get") else -1 + if hasattr(frame_ep, "item"): + frame_ep = frame_ep.item() + if frame_ep == episode: + if episode_length == 0: + episode_start = idx + episode_length += 1 + elif episode_length > 0: + break + + if episode_length == 0: + raise ValueError(f"Episode {episode} has no frames") + + return ds, episode_start, episode_length diff --git a/strands_robots/policies/__init__.py b/strands_robots/policies/__init__.py index 6cffe18..048e268 100644 --- a/strands_robots/policies/__init__.py +++ b/strands_robots/policies/__init__.py @@ -1,6 +1,6 @@ """Policy Abstraction for Universal VLA Support. -Plugin-based registry — all provider definitions live in registry/policies.json. 
+Plugin-based registry - all provider definitions live in registry/policies.json. No hardcoded if/elif chains. New providers are auto-discovered or registered at runtime. Built-in providers (see policies.json for full list): diff --git a/strands_robots/policies/base.py b/strands_robots/policies/base.py index a082f4c..c2e5ed6 100644 --- a/strands_robots/policies/base.py +++ b/strands_robots/policies/base.py @@ -54,6 +54,18 @@ def set_robot_state_keys(self, robot_state_keys: list[str]) -> None: """Configure the policy with robot state keys.""" pass + @property + def requires_images(self) -> bool: + """Whether this policy needs camera frames in its observation. + + Default True (most VLA policies do). Subclasses that only consume + joint state (e.g. ``MockPolicy``, pure-IK controllers, scripted + trajectories) can return ``False`` to let the simulation skip + expensive camera rendering - a ~10x throughput win at 500Hz when + no cameras are needed. + """ + return True + @property @abstractmethod def provider_name(self) -> str: diff --git a/strands_robots/policies/factory.py b/strands_robots/policies/factory.py index 062978d..e25c842 100644 --- a/strands_robots/policies/factory.py +++ b/strands_robots/policies/factory.py @@ -1,4 +1,4 @@ -"""Policy factory — create_policy() and runtime registration.""" +"""Policy factory - create_policy() and runtime registration.""" import logging import os @@ -9,9 +9,9 @@ logger = logging.getLogger(__name__) -# ───────────────────────────────────────────────────────────────────── +# # Runtime registration (for user-defined providers not in JSON) -# ───────────────────────────────────────────────────────────────────── +# _runtime_registry: dict[str, Callable[[], type[Policy]]] = {} _runtime_aliases: dict[str, str] = {} diff --git a/strands_robots/policies/groot/__init__.py b/strands_robots/policies/groot/__init__.py index db09efe..8884c0c 100644 --- a/strands_robots/policies/groot/__init__.py +++ 
b/strands_robots/policies/groot/__init__.py @@ -1,4 +1,4 @@ -"""GR00T Policy — NVIDIA GR00T N1.5 and N1.6 support. +"""GR00T Policy - NVIDIA GR00T N1.5 and N1.6 support. Two inference modes: diff --git a/strands_robots/policies/groot/client.py b/strands_robots/policies/groot/client.py index e8cf0e6..5bf4257 100644 --- a/strands_robots/policies/groot/client.py +++ b/strands_robots/policies/groot/client.py @@ -1,4 +1,4 @@ -"""GR00T inference client — ZMQ client for inference-service communication. +"""GR00T inference client - ZMQ client for inference-service communication. Handles serialization of numpy arrays and ModalityConfig objects over ZMQ using msgpack with custom encode/decode hooks. @@ -138,7 +138,7 @@ def ping(self) -> bool: """Check server connectivity. Returns True if the server responds, False otherwise. - Does NOT auto-reconnect — call :meth:`reconnect` explicitly if needed. + Does NOT auto-reconnect - call :meth:`reconnect` explicitly if needed. """ try: self.call_endpoint("ping") @@ -185,7 +185,7 @@ def get_action(self, observations: dict[str, Any]) -> dict[str, Any]: is currently empty in all upstream embodiments. """ response = self.call_endpoint("get_action", {"observation": observations, "options": None}) - # N1.6/N1.7 servers return a (action_dict, info_dict) tuple—msgpack + # N1.6/N1.7 servers return a (action_dict, info_dict) tuple - msgpack # decodes tuples as lists, so we may see either shape here. if isinstance(response, list | tuple) and len(response) == 2: action, _info = response diff --git a/strands_robots/policies/groot/data_config.py b/strands_robots/policies/groot/data_config.py index e5fc879..6dc30d3 100644 --- a/strands_robots/policies/groot/data_config.py +++ b/strands_robots/policies/groot/data_config.py @@ -1,4 +1,4 @@ -"""GR00T data configuration — typed embodiment key mappings. +"""GR00T data configuration - typed embodiment key mappings. 
Provides :class:`Gr00tDataConfig` dataclasses and an ``_extends`` inheritance mechanism so new robot configs can be defined by overriding only what differs @@ -59,9 +59,7 @@ def modality_config(self) -> dict[str, ModalityConfig]: } -# --------------------------------------------------------------------------- # Config resolution with _extends inheritance -# --------------------------------------------------------------------------- def _resolve_config(name: str, definitions: dict) -> Gr00tDataConfig: @@ -88,9 +86,7 @@ def _resolve_config(name: str, definitions: dict) -> Gr00tDataConfig: return Gr00tDataConfig(**merged) -# --------------------------------------------------------------------------- # Load configs from JSON -# --------------------------------------------------------------------------- _CONFIG_FILE = Path(__file__).parent / "data_configs.json" diff --git a/strands_robots/policies/groot/policy.py b/strands_robots/policies/groot/policy.py index 162a80f..8d23bae 100644 --- a/strands_robots/policies/groot/policy.py +++ b/strands_robots/policies/groot/policy.py @@ -1,4 +1,4 @@ -"""GR00T policy — N1.5/N1.6 service and local inference. +"""GR00T policy - N1.5/N1.6 service and local inference. Implements :class:`~strands_robots.policies.base.Policy` for NVIDIA GR00T models. @@ -33,9 +33,7 @@ logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- # Isaac-GR00T version detection -# --------------------------------------------------------------------------- _GROOT_VERSION: str | None = None # "n1.5", "n1.6", "n1.7", or None @@ -61,7 +59,7 @@ def _detect_groot_version(*, force: bool = False) -> str | None: # Reset before re-detection _GROOT_VERSION = None - # N1.7 first — the new Cosmos-Reason2-2B backbone lives here. + # N1.7 first - the new Cosmos-Reason2-2B backbone lives here. # Detecting by subpackage (not enum values) keeps the probe cheap. 
try: if importlib.util.find_spec("gr00t.model.gr00t_n1d7") is not None: @@ -90,9 +88,7 @@ def _detect_groot_version(*, force: bool = False) -> str | None: return None -# --------------------------------------------------------------------------- # Mapping dataclasses -# --------------------------------------------------------------------------- @dataclass(frozen=True) @@ -100,8 +96,8 @@ class ObservationMapping: """Maps robot sensor names → model modality keys. Attributes: - video: ``{robot_camera: model_video_key}`` — bare, no prefix. - state: ``{robot_state: model_state_key}`` — bare, no prefix. + video: ``{robot_camera: model_video_key}`` - bare, no prefix. + state: ``{robot_state: model_state_key}`` - bare, no prefix. language_key: Model's language key (e.g. ``"task"``). """ @@ -139,7 +135,7 @@ class ActionMapping: """Maps model action keys → robot actuator names. Attributes: - actions: ``{model_action_key: robot_actuator}`` — bare, no prefix. + actions: ``{model_action_key: robot_actuator}`` - bare, no prefix. 
""" actions: dict[str, str] = field(default_factory=dict) @@ -152,9 +148,7 @@ def validate(self, modality_configs: dict) -> None: raise ValueError(f"Action mapping: model key '{model_key}' not in model: {sorted(model_action)}") -# --------------------------------------------------------------------------- # Auto-inference (exact name match → positional fallback) -# --------------------------------------------------------------------------- def _auto_infer_observation_mapping( @@ -214,9 +208,7 @@ def _match_keys(ours: list[str], model: list[str], label: str) -> dict[str, str] return mapping -# --------------------------------------------------------------------------- # Parse user-provided flat mapping dicts -# --------------------------------------------------------------------------- def _parse_observation_mapping( @@ -247,13 +239,11 @@ def _parse_action_mapping(flat: dict[str, str]) -> ActionMapping: return ActionMapping(actions={k.removeprefix("action."): v for k, v in flat.items()}) -# --------------------------------------------------------------------------- # Gr00tPolicy -# --------------------------------------------------------------------------- class Gr00tPolicy(Policy): - """GR00T policy — service mode and local inference (N1.5/N1.6). + """GR00T policy - service mode and local inference (N1.5/N1.6). For **local mode**, loads the model directly and talks its native nested-dict format. Robot↔model key translation is done by explicit mappings. 
@@ -317,7 +307,7 @@ def __init__( self._groot_version = groot_version or _detect_groot_version() self._strict = strict - # DOF per model state key — discovered from model at load time + # DOF per model state key - discovered from model at load time self._model_state_dof: dict[str, int] = {} # Raw user mappings (parsed after model load) @@ -348,9 +338,7 @@ def __init__( self.data_config_name, ) - # ------------------------------------------------------------------ # Mapping initialization - # ------------------------------------------------------------------ def _init_mappings(self) -> None: """Initialize observation/action mappings after model load.""" @@ -463,16 +451,14 @@ def _discover_model_state_dof(self, mmc: dict) -> None: missing = all_keys - discovered if missing: logger.warning( - "Could not discover DOF for state keys: %s — these will not be zero-filled if unmapped", + "Could not discover DOF for state keys: %s - these will not be zero-filled if unmapped", sorted(missing), ) if self._model_state_dof: logger.info("Model state DOF: %s", self._model_state_dof) - # ------------------------------------------------------------------ # Model loading - # ------------------------------------------------------------------ def _load_local_policy(self, model_path: str, embodiment_tag: str, device: str): if self._groot_version == "n1.7": @@ -504,7 +490,7 @@ def _load_n15(self, model_path: str, embodiment_tag: str, device: str): logger.info("GR00T N1.5 loaded from %s", model_path) def _load_n16(self, model_path: str, embodiment_tag: str, device: str): - """Load N1.6 — uses Gr00tPolicy directly (NOT SimPolicyWrapper).""" + """Load N1.6 - uses Gr00tPolicy directly (NOT SimPolicyWrapper).""" from gr00t.data.embodiment_tags import EmbodimentTag from gr00t.policy.gr00t_policy import Gr00tPolicy as N16Policy @@ -518,7 +504,7 @@ def _load_n16(self, model_path: str, embodiment_tag: str, device: str): logger.info("GR00T N1.6 loaded from %s (direct)", model_path) def 
_load_n17(self, model_path: str, embodiment_tag: str, device: str): - """Load N1.7 — identical entry point to N1.6 (same ``Gr00tPolicy`` signature). + """Load N1.7 - identical entry point to N1.6 (same ``Gr00tPolicy`` signature). The user-visible policy class is still ``gr00t.policy.gr00t_policy.Gr00tPolicy``; internally it pulls the new Cosmos-Reason2-2B / Qwen3-VL backbone via @@ -537,9 +523,7 @@ def _load_n17(self, model_path: str, embodiment_tag: str, device: str): ) logger.info("GR00T N1.7 loaded from %s (direct)", model_path) - # ------------------------------------------------------------------ # Policy interface - # ------------------------------------------------------------------ @property def provider_name(self) -> str: @@ -553,9 +537,7 @@ async def get_actions(self, observation_dict: dict[str, Any], instruction: str, return self._local_get_actions(observation_dict, instruction) return self._service_get_actions(observation_dict, instruction) - # ------------------------------------------------------------------ - # Local inference — talks model's native nested-dict format - # ------------------------------------------------------------------ + # Local inference - talks model's native nested-dict format def _local_get_actions(self, robot_obs: dict[str, Any], instruction: str) -> list[dict[str, Any]]: """Local: prepare nested obs → infer → unpack actions.""" @@ -589,7 +571,7 @@ def _prepare_observation(self, robot_obs: dict[str, Any], instruction: str) -> d assert self._obs_mapping is not None, "Observation mapping not initialized" - # ── Video ── + # Video mapped_video_keys = set(self._obs_mapping.video.keys()) for robot_key, model_key in self._obs_mapping.video.items(): if robot_key in robot_obs: @@ -603,7 +585,7 @@ def _prepare_observation(self, robot_obs: dict[str, Any], instruction: str) -> d ref = _reference_video_shape(robot_obs, mapped_video_keys) video_dict[model_key] = np.zeros((1, 1, *ref), dtype=np.uint8) - # ── State ── + # State for 
robot_key, model_key in self._obs_mapping.state.items(): if robot_key in robot_obs: state_dict[model_key] = _to_state_batch(robot_obs[robot_key]) @@ -619,11 +601,11 @@ def _prepare_observation(self, robot_obs: dict[str, Any], instruction: str) -> d state_dict[model_key] = np.zeros((1, 1, dof), dtype=np.float32) else: logger.debug( - "Skipping zero-fill for '%s' — DOF unknown", + "Skipping zero-fill for '%s' - DOF unknown", model_key, ) - # ── Language ── + # Language lang_key = self._obs_mapping.language_key language_dict = {lang_key: [[instruction]]} @@ -663,9 +645,7 @@ def _unpack_actions(self, raw_actions: dict) -> list[dict[str, Any]]: return actions - # ------------------------------------------------------------------ # Service inference - # ------------------------------------------------------------------ def _service_get_actions(self, robot_obs: dict[str, Any], instruction: str) -> list[dict[str, Any]]: """Service mode: build observation, call server, unpack.""" @@ -735,7 +715,7 @@ def _unpack_service_actions(self, action_chunk: dict) -> list[dict[str, Any]]: actions.append(step) return actions - # No mapping — return bare model keys + # No mapping - return bare model keys actions = [] for t in range(horizon): step = {} @@ -746,9 +726,7 @@ def _unpack_service_actions(self, action_chunk: dict) -> list[dict[str, Any]]: return actions -# --------------------------------------------------------------------------- -# Shape helpers — match Isaac-GR00T's expected formats exactly -# --------------------------------------------------------------------------- +# Shape helpers - match Isaac-GR00T's expected formats exactly def _to_video_batch(value: np.ndarray) -> np.ndarray: diff --git a/strands_robots/policies/lerobot_local/__init__.py b/strands_robots/policies/lerobot_local/__init__.py index 7f1c281..c75ce8b 100644 --- a/strands_robots/policies/lerobot_local/__init__.py +++ b/strands_robots/policies/lerobot_local/__init__.py @@ -1,4 +1,4 @@ -"""LeRobot Local 
Policy — Direct HuggingFace model inference (no server needed).""" +"""LeRobot Local Policy - Direct HuggingFace model inference (no server needed).""" from .policy import LerobotLocalPolicy diff --git a/strands_robots/policies/lerobot_local/policy.py b/strands_robots/policies/lerobot_local/policy.py index b8e5817..a9e9171 100644 --- a/strands_robots/policies/lerobot_local/policy.py +++ b/strands_robots/policies/lerobot_local/policy.py @@ -1,4 +1,4 @@ -"""LeRobot Local Policy — Direct HuggingFace model inference (no server needed). +"""LeRobot Local Policy - Direct HuggingFace model inference (no server needed). Uses LeRobot's own factory for auto-detection. Any model LeRobot supports, this policy supports. @@ -186,9 +186,7 @@ def set_robot_state_keys(self, robot_state_keys: list[str]) -> None: "Call set_robot_state_keys() with the robot's actual joint/motor names." ) - # ------------------------------------------------------------------ # Tokenizer resolution (VLA language token injection) - # ------------------------------------------------------------------ def _resolve_tokenizer(self) -> Any | None: """Resolve and cache the tokenizer for VLA language token injection. @@ -288,9 +286,7 @@ def _needs_language_tokens(self) -> bool: return False - # ------------------------------------------------------------------ # Model loading - # ------------------------------------------------------------------ def _load_model(self) -> None: """Load the LeRobot model from pretrained path. @@ -381,7 +377,7 @@ def _load_model(self) -> None: self._processor_bridge = None logger.debug("No processor configs found, using raw obs/action flow") except (FileNotFoundError, ValueError, ImportError) as exc: - # Processor bridge is optional — models work without it via raw obs/action flow. + # Processor bridge is optional - models work without it via raw obs/action flow. # Fail-fast only if the user explicitly requested processor overrides. 
if self.processor_overrides: raise RuntimeError( @@ -393,9 +389,7 @@ def _load_model(self) -> None: # Initialize RTC if supported by this policy self._init_rtc() - # ------------------------------------------------------------------ # Real-Time Chunking (RTC) support - # ------------------------------------------------------------------ def _init_rtc(self) -> None: """Initialize RTC if the loaded policy supports it. @@ -422,7 +416,7 @@ def _init_rtc(self) -> None: return # Auto-detect from model config. - # RTC requires rtc_config on the model — not just predict_action_chunk(). + # RTC requires rtc_config on the model - not just predict_action_chunk(). # In LeRobot 0.5+, predict_action_chunk() is a base class method that ALL # policies inherit (ACT, Diffusion, etc.), but only flow-matching policies # (Pi0, SmolVLA) have an rtc_config that parameterizes the denoiser for @@ -437,7 +431,7 @@ def _init_rtc(self) -> None: elif self._rtc_requested is True: if rtc_config is None: # User explicitly asked for RTC, but this policy has no rtc_config. - # This means it's not a flow-matching policy — warn and disable. + # This means it's not a flow-matching policy - warn and disable. logger.warning( "RTC requested but policy '%s' has no rtc_config. " "RTC is only supported by flow-matching policies (Pi0, SmolVLA). " @@ -509,7 +503,7 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: batch: Observation batch tensors ready for the policy. Returns: - Action tensor — first action(s) from the chunk, accounting for + Action tensor - first action(s) from the chunk, accounting for inference delay. 
""" inference_start = time.time() @@ -532,7 +526,7 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: if action_chunk.dim() == 3 and action_chunk.shape[0] == 1: action_chunk = action_chunk.squeeze(0) - # Estimate inference delay — how many steps were consumed while computing + # Estimate inference delay - how many steps were consumed while computing inference_delay = self._estimate_inference_delay() # Store leftover for next RTC call (unconsumed portion of this chunk) @@ -547,11 +541,11 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: else: self._rtc_prev_chunk = None - # Skip delay steps — they correspond to time spent during inference + # Skip delay steps - they correspond to time spent during inference usable_start = min(inference_delay, action_chunk.shape[0] - 1) usable_actions = action_chunk[usable_start:] - # Log RTC details at debug level — throttled to once every 2s regardless of Hz + # Log RTC details at debug level - throttled to once every 2s regardless of Hz _now = time.monotonic() if _now - self._rtc_last_log_time >= 2.0: self._rtc_last_log_time = _now @@ -566,9 +560,7 @@ def _predict_with_rtc(self, batch: dict[str, Any]) -> torch.Tensor: return usable_actions - # ------------------------------------------------------------------ # Inference - # ------------------------------------------------------------------ async def get_actions(self, observation_dict: dict[str, Any], instruction: str, **kwargs) -> list[dict[str, Any]]: """Get actions from policy given observation and instruction. 
@@ -637,9 +629,7 @@ async def get_actions(self, observation_dict: dict[str, Any], instruction: str, return self._tensor_to_action_dicts(action_tensor) - # ------------------------------------------------------------------ # Observation batch building - # ------------------------------------------------------------------ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: """Fix up a preprocessor-produced batch so every value is a proper batched tensor. @@ -666,7 +656,7 @@ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: fixed: dict[str, Any] = {} for key, val in batch.items(): - # --- numpy arrays → torch tensors --- + # numpy arrays → torch tensors if isinstance(val, np.ndarray): if "image" in key: # HWC uint8 → CHW float32 → (1,C,H,W) @@ -682,7 +672,7 @@ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: t = t.unsqueeze(0) # (D,) → (1,D) fixed[key] = t.to(device) - # --- torch tensors: ensure batch dim + device --- + # torch tensors: ensure batch dim + device elif isinstance(val, torch.Tensor): # Auto-cast float64 → float32: ROS/dynamixel drivers often produce float64 t = val.float() if val.dtype == torch.float64 else val @@ -695,7 +685,7 @@ def _fixup_preprocessed_batch(self, batch: dict[str, Any]) -> dict[str, Any]: t = t.unsqueeze(0) # (D,) → (1,D) fixed[key] = t.to(device) - # --- pass through anything else (strings, etc.) --- + # pass through anything else (strings, etc.) else: fixed[key] = val @@ -772,7 +762,7 @@ def _build_batch_from_lerobot_format( - Scalars → float32 tensor with batch dim Non-numeric types (strings, pre-batched int64 tokens) are passed through - unchanged — LeRobot expects these as-is for task descriptions and + unchanged - LeRobot expects these as-is for task descriptions and pre-tokenized inputs. 
Args: @@ -813,7 +803,7 @@ def _build_batch_from_lerobot_format( is_image = True if is_image and tensor.dim() == 3 and tensor.shape[-1] in (1, 3, 4): tensor = tensor.permute(2, 0, 1) - # uint8 images are [0, 255] — normalize to [0, 1] for model input + # uint8 images are [0, 255] - normalize to [0, 1] for model input if is_image and value.dtype == np.uint8: tensor = tensor / 255.0 if is_image and tensor.dim() == 3: @@ -829,7 +819,7 @@ def _build_batch_from_lerobot_format( try: array = np.array(value, dtype=np.float32) except (ValueError, TypeError): - # Non-numeric lists (e.g. string lists) — skip silently, they aren't tensor data + # Non-numeric lists (e.g. string lists) - skip silently, they aren't tensor data logger.debug("Skipping non-numeric list/tuple for key in observation batch") continue tensor = torch.from_numpy(array).float() @@ -868,7 +858,7 @@ def _build_batch_from_strands_format( """ if not self.robot_state_keys: raise ValueError( - "robot_state_keys is empty — cannot map observation to state tensor. " + "robot_state_keys is empty - cannot map observation to state tensor. " "Call set_robot_state_keys() with the robot's motor names." ) @@ -895,7 +885,7 @@ def _build_batch_from_strands_format( expected_dim = state_feature.shape[0] if hasattr(state_feature, "shape") else len(state_values) if len(state_values) > expected_dim: logger.warning( - "State dim %d > model expects %d — truncating to first %d values. " + "State dim %d > model expects %d - truncating to first %d values. " "Check that robot_state_keys matches your robot's actual joint count.", len(state_values), expected_dim, @@ -904,7 +894,7 @@ def _build_batch_from_strands_format( state_values = state_values[:expected_dim] elif len(state_values) < expected_dim: logger.warning( - "State dim %d < model expects %d — zero-padding with %d zeros. " + "State dim %d < model expects %d - zero-padding with %d zeros. 
" "Check that robot_state_keys matches your robot's actual joint count.", len(state_values), expected_dim, @@ -936,9 +926,7 @@ def _build_batch_from_strands_format( return batch - # ------------------------------------------------------------------ # Action conversion - # ------------------------------------------------------------------ def _tensor_to_action_dicts(self, action_tensor: torch.Tensor) -> list[dict[str, Any]]: """Convert action tensor to list of robot action dicts. diff --git a/strands_robots/policies/lerobot_local/processor.py b/strands_robots/policies/lerobot_local/processor.py index 30362b6..7b27237 100644 --- a/strands_robots/policies/lerobot_local/processor.py +++ b/strands_robots/policies/lerobot_local/processor.py @@ -127,7 +127,7 @@ def from_pretrained( ) logger.info("Loaded preprocessor from %s: %d steps", pretrained_name_or_path, len(preprocessor)) except (FileNotFoundError, ValueError) as exc: - # No config file found — model doesn't ship a preprocessor. This is normal. + # No config file found - model doesn't ship a preprocessor. This is normal. logger.debug("No preprocessor found: %s", exc) # Load postprocessor @@ -139,7 +139,7 @@ def from_pretrained( ) logger.info("Loaded postprocessor from %s: %d steps", pretrained_name_or_path, len(postprocessor)) except (FileNotFoundError, ValueError) as exc: - # No config file found — model doesn't ship a postprocessor. This is normal. + # No config file found - model doesn't ship a postprocessor. This is normal. logger.debug("No postprocessor found: %s", exc) return cls( diff --git a/strands_robots/policies/lerobot_local/resolution.py b/strands_robots/policies/lerobot_local/resolution.py index 401b666..9783e90 100644 --- a/strands_robots/policies/lerobot_local/resolution.py +++ b/strands_robots/policies/lerobot_local/resolution.py @@ -35,7 +35,7 @@ def _ensure_policy_configs_registered() -> None: each config module has module-level side effects that populate the registry. 
This function imports a single known config to bootstrap the entire registry. - It's safe to call multiple times — the import is idempotent. + It's safe to call multiple times - the import is idempotent. """ global _CONFIGS_REGISTERED if _CONFIGS_REGISTERED: @@ -49,7 +49,7 @@ def _ensure_policy_configs_registered() -> None: _CONFIGS_REGISTERED = True logger.debug("LeRobot policy configs registered in draccus choice registry") except (ImportError, ModuleNotFoundError): - # Pre-0.5 lerobot or missing policy subpackage — that's OK, + # Pre-0.5 lerobot or missing policy subpackage - that's OK, # the caller will fall through to manual resolution. logger.debug("Could not import lerobot policy configs for draccus registration") except Exception as exc: @@ -97,7 +97,7 @@ class lookup, and weight loading via the draccus config registry. logger.debug("PreTrainedConfig resolution failed, trying manual: %s", exc) except Exception as exc: # draccus raises DecodingError/ParsingError which are NOT subclasses - # of RuntimeError/ValueError — they inherit from DraccusException → Exception. + # of RuntimeError/ValueError - they inherit from DraccusException → Exception. # Catch broadly here but only for draccus-related errors. if "draccus" in type(exc).__module__ or "DecodingError" in type(exc).__name__: logger.debug("PreTrainedConfig draccus error, trying manual: %s", exc) @@ -124,7 +124,7 @@ def _ensure_lerobot_policies_importable() -> None: its ``__init__.py``. LeRobot 0.5+ has a ``lerobot/policies/__init__.py`` that eagerly imports - **all** policy packages (groot, act, diffusion, …). The groot import chain + **all** policy packages (groot, act, diffusion, ...). The groot import chain pulls in ``transformers`` → ``flash_attn`` which can crash at module load time on environments with ABI mismatches (e.g. wrong torch / flash-attn version combo). 
@@ -145,7 +145,7 @@ def _ensure_lerobot_policies_importable() -> None: key = "lerobot.policies" if key in sys.modules: - # Already imported (successfully or via a previous stub) — nothing to do. + # Already imported (successfully or via a previous stub) - nothing to do. return try: @@ -234,7 +234,7 @@ def resolve_policy_class_by_name(policy_type: str) -> type[Any]: except (ImportError, AttributeError, RuntimeError): pass - # Strategy 4: PreTrainedPolicy — only if it's NOT abstract + # Strategy 4: PreTrainedPolicy - only if it's NOT abstract try: from lerobot.policies.pretrained import PreTrainedPolicy diff --git a/strands_robots/policies/mock.py b/strands_robots/policies/mock.py index e4fa38c..6c97b02 100644 --- a/strands_robots/policies/mock.py +++ b/strands_robots/policies/mock.py @@ -1,4 +1,4 @@ -"""Mock policy for testing — generates smooth sinusoidal trajectories.""" +"""Mock policy for testing - generates smooth sinusoidal trajectories.""" import logging import math @@ -10,7 +10,7 @@ class MockPolicy(Policy): - """Mock policy for testing — generates smooth sinusoidal trajectories.""" + """Mock policy for testing - generates smooth sinusoidal trajectories.""" def __init__(self, **kwargs: Any) -> None: self.robot_state_keys: list[str] = [] @@ -21,6 +21,11 @@ def __init__(self, **kwargs: Any) -> None: def provider_name(self) -> str: return "mock" + @property + def requires_images(self) -> bool: + """Mock policy only consumes joint state - skip camera rendering.""" + return False + def set_robot_state_keys(self, robot_state_keys: list[str]) -> None: self.robot_state_keys = robot_state_keys diff --git a/strands_robots/registry/__init__.py b/strands_robots/registry/__init__.py index 2d6ba3d..ef36cb5 100644 --- a/strands_robots/registry/__init__.py +++ b/strands_robots/registry/__init__.py @@ -1,4 +1,4 @@ -"""Unified Registry — single source of truth for robots and policies. +"""Unified Registry - single source of truth for robots and policies. 
Loads robot definitions and policy provider configs from JSON files. @@ -6,7 +6,7 @@ - **One file to edit**: Add a robot → edit robots.json, done. - **Hot-reload**: JSON is re-read when the file changes (mtime check). - **Self-contained entries**: Each robot/policy owns its aliases, - shorthands, and URL patterns — no separate lookup tables. + shorthands, and URL patterns - no separate lookup tables. - **Validation**: Duplicate aliases, shorthands, and URL patterns are caught on load with clear error messages. diff --git a/strands_robots/registry/policies.py b/strands_robots/registry/policies.py index 07ea6f2..c525b90 100644 --- a/strands_robots/registry/policies.py +++ b/strands_robots/registry/policies.py @@ -1,4 +1,4 @@ -"""Policy registry — resolve, import, and configure policy providers. +"""Policy registry - resolve, import, and configure policy providers. All provider definitions live in policies.json. This module provides the public read API for resolving smart policy strings, importing provider @@ -63,7 +63,7 @@ def resolve_policy(policy: str, **extra_kwargs) -> tuple[str, dict[str, Any]]: 5. Fallback to lerobot_local Args: - policy: Smart string — HF model ID, URL, or provider name. + policy: Smart string - HF model ID, URL, or provider name. **extra_kwargs: Additional kwargs merged into result. Returns: @@ -84,7 +84,7 @@ def resolve_policy(policy: str, **extra_kwargs) -> tuple[str, dict[str, Any]]: policy = policy.strip() kwargs: dict[str, Any] = {} - # 1. URL pattern matching — check each provider's url_patterns + # 1. URL pattern matching - check each provider's url_patterns for prov_name, prov_info in providers.items(): for pattern in prov_info.get("url_patterns", []): if re.match(pattern, policy): @@ -105,7 +105,7 @@ def resolve_policy(policy: str, **extra_kwargs) -> tuple[str, dict[str, Any]]: kwargs.update(extra_kwargs) return prov_name, kwargs - # 2. Shorthand names — built from each provider's shorthands list + # 2. 
Shorthand names - built from each provider's shorthands list alias_map = _build_alias_map() if policy.lower() in alias_map: kwargs.update(extra_kwargs) diff --git a/strands_robots/registry/robots.py b/strands_robots/registry/robots.py index 779f5fa..81d873b 100644 --- a/strands_robots/registry/robots.py +++ b/strands_robots/registry/robots.py @@ -1,4 +1,4 @@ -"""Robot registry — query, resolve, and list robot definitions. +"""Robot registry - query, resolve, and list robot definitions. All robot definitions live in robots.json. This module provides the public read API; the JSON file is the only thing you edit to add @@ -55,7 +55,7 @@ def get_robot(name: str) -> dict[str, Any] | None: Returns: Robot dict with keys like description, category, joints, asset, - hardware — or None if not found. + hardware - or None if not found. """ reg = _load("robots") canonical = resolve_name(name) @@ -92,7 +92,7 @@ def list_robots(mode: str = "all") -> list[dict[str, Any]]: """List available robots, optionally filtered. Args: - mode: Filter — "all", "sim", "real", or "both" (has sim AND real). + mode: Filter - "all", "sim", "real", or "both" (has sim AND real). Returns: List of dicts with name, description, has_sim, has_real. @@ -137,19 +137,54 @@ def list_aliases() -> dict[str, str]: return _build_alias_map() -def format_robot_table() -> str: - """Human-readable table of all robots for CLI/tool output.""" - lines = [ - f"{'Name':<20} {'Category':<15} {'Joints':<8} {'Sim':<5} {'Real':<5} Description", - "─" * 100, - ] - for cat in ["arm", "bimanual", "hand", "humanoid", "expressive", "mobile", "mobile_manip"]: +_NAME_WIDTH = 20 +_CAT_WIDTH = 15 +_JOINTS_WIDTH = 8 +_SIM_WIDTH = 5 +_REAL_WIDTH = 5 +# Width of the fixed prefix columns, including single-space separators. 
+_FIXED_PREFIX_WIDTH = _NAME_WIDTH + 1 + _CAT_WIDTH + 1 + _JOINTS_WIDTH + 1 + _SIM_WIDTH + 1 + _REAL_WIDTH + 1 + + +def format_robot_table(max_width: int = 100) -> str: + """Human-readable table of all robots for CLI/tool output. + + Args: + max_width: Target terminal width. The ``Description`` column is + truncated with an ellipsis to fit. Pass a large value (e.g. + ``1000``) to disable truncation entirely. Default 100 is safe + for a typical 100-column terminal. + """ + desc_width = max(20, max_width - _FIXED_PREFIX_WIDTH) + + header = ( + f"{'Name':<{_NAME_WIDTH}} " + f"{'Category':<{_CAT_WIDTH}} " + f"{'Joints':<{_JOINTS_WIDTH}} " + f"{'Sim':<{_SIM_WIDTH}} " + f"{'Real':<{_REAL_WIDTH}} " + f"Description" + ) + rule_width = min(max(max_width, len(header)), _FIXED_PREFIX_WIDTH + desc_width) + lines = [header, "─" * rule_width] + + for cat in ["arm", "bimanual", "hand", "humanoid", "expressive", "mobile", "mobile_manip", "aerial"]: by_cat = list_robots_by_category() for r in by_cat.get(cat, []): sim = "✅" if r["has_sim"] else " " real = "✅" if r["has_real"] else " " joints = str(r["joints"]) if r["joints"] else "?" - lines.append(f"{r['name']:<20} {r['category']:<15} {joints:<8} {sim:<5} {real:<5} {r['description']}") + desc = r["description"] or "" + if len(desc) > desc_width: + desc = desc[: desc_width - 3].rstrip() + "..." + lines.append( + f"{r['name']:<{_NAME_WIDTH}} " + f"{r['category']:<{_CAT_WIDTH}} " + f"{joints:<{_JOINTS_WIDTH}} " + f"{sim:<{_SIM_WIDTH}} " + f"{real:<{_REAL_WIDTH}} " + f"{desc}" + ) robots = list_robots() lines.append("") diff --git a/strands_robots/registry/user_registry.py b/strands_robots/registry/user_registry.py index eb55843..8f67f14 100644 --- a/strands_robots/registry/user_registry.py +++ b/strands_robots/registry/user_registry.py @@ -1,4 +1,4 @@ -"""User-local robot registry — runtime registration without editing package JSON. +"""User-local robot registry - runtime registration without editing package JSON. 
Provides ``register_robot()`` and ``unregister_robot()`` for adding custom robots that persist across sessions via a ``user_robots.json`` file stored @@ -13,7 +13,7 @@ user registry. Use ``STRANDS_BASE_DIR`` to relocate user metadata. At load time the user overlay is merged *on top of* the package -``robots.json`` — user entries win on name collision, so you can also +``robots.json`` - user entries win on name collision, so you can also override built-in robots locally. Usage:: @@ -34,7 +34,7 @@ from strands_robots.simulation import create_simulation sim = create_simulation() sim.create_world() - sim.add_robot("my_arm") # ✅ auto-resolved + sim.add_robot("my_arm") # auto-resolved # Remove it unregister_robot("my_arm") @@ -176,7 +176,7 @@ def register_robot( if _pkg_get_robot(name) is not None: logger.info( - "Robot '%s' exists in package registry — user registration will override it.", + "Robot '%s' exists in package registry - user registration will override it.", name, ) except ImportError: @@ -189,7 +189,7 @@ def register_robot( # This matches how resolve_model_path works: search_dir / asset["dir"] / xml dir_name = resolved_dir.name - # Alias collision detection — warn (don't fail) when a user alias shadows a + # Alias collision detection - warn (don't fail) when a user alias shadows a # canonical name or another alias. Doing this at registration surfaces the # problem immediately instead of at silent resolution-order time. if aliases and not overwrite: @@ -218,7 +218,7 @@ def register_robot( logger.warning("Alias %r is already used by another robot.", alias) # Validate model_xml exists. Previously we only checked when - # ``resolved_dir`` existed — which silently accepted registrations for + # ``resolved_dir`` existed - which silently accepted registrations for # dirs that didn't exist yet and surfaced a confusing error only at # ``add_robot()`` time. Now we fail-closed on both conditions so the # user gets an immediate, actionable error at registration time. 
@@ -283,7 +283,7 @@ def unregister_robot(name: str) -> bool: data = _load_user_registry() if name not in data.get("robots", {}): - logger.info("Robot '%s' not in user registry — nothing to remove.", name) + logger.info("Robot '%s' not in user registry - nothing to remove.", name) return False del data["robots"][name] diff --git a/strands_robots/robot.py b/strands_robots/robot.py index 33c5ba7..3927afe 100644 --- a/strands_robots/robot.py +++ b/strands_robots/robot.py @@ -236,7 +236,7 @@ async def _connect_robot(self) -> tuple[bool, str]: # Check if already connected if self.robot.is_connected: - logger.info(f"✅ {self.robot} already connected") + logger.info(f"{self.robot} already connected") return True, "" logger.info(f"🔌 Connecting to {self.robot}...") @@ -248,13 +248,13 @@ async def _connect_robot(self) -> tuple[bool, str]: except DeviceAlreadyConnectedError: # This is expected and fine - robot is already connected - logger.info(f"✅ {self.robot} was already connected") + logger.info(f"{self.robot} was already connected") except Exception as e: # Check if it's the string version of "already connected" error error_str = str(e).lower() if "already connected" in error_str or "is already connected" in error_str: - logger.info(f"✅ {self.robot} connection already established") + logger.info(f"{self.robot} connection already established") else: # Re-raise if it's a different error raise e @@ -262,7 +262,7 @@ async def _connect_robot(self) -> tuple[bool, str]: # Final connection check if not self.robot.is_connected: error_msg = f"Failed to connect to {self.robot}" - logger.error(f"❌ {error_msg}") + logger.error(f"{error_msg}") return False, error_msg # Check robot calibration @@ -271,15 +271,15 @@ async def _connect_robot(self) -> tuple[bool, str]: f"Robot {self.robot} is not calibrated. 
Please calibrate the robot manually" " first using LeRobot's calibration process (lerobot-calibrate)" ) - logger.error(f"❌ {error_msg}") + logger.error(f"{error_msg}") return False, error_msg - logger.info(f"✅ {self.robot} connected and ready") + logger.info(f"{self.robot} connected and ready") return True, "" except Exception as e: error_msg = f"Robot connection failed: {e}. Ensure robot is calibrated and accessible on the specified port" - logger.error(f"❌ {error_msg}") + logger.error(f"{error_msg}") return False, error_msg async def _initialize_policy(self, policy: Policy) -> bool: @@ -300,7 +300,7 @@ async def _initialize_policy(self, policy: Policy) -> bool: return True except Exception as e: - logger.error(f"❌ Failed to initialize policy: {e}") + logger.error(f"Failed to initialize policy: {e}") return False async def _execute_task_async( @@ -370,12 +370,10 @@ async def _execute_task_async( if self._task_state.status == TaskStatus.RUNNING: self._task_state.status = TaskStatus.COMPLETED - logger.info( - f"✅ Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)" - ) + logger.info(f"Task completed: '{instruction}' in {elapsed:.1f}s ({self._task_state.step_count} steps)") except Exception as e: - logger.error(f"❌ Task execution failed: {e}") + logger.error(f"Task execution failed: {e}") self._task_state.status = TaskStatus.ERROR self._task_state.error_message = str(e) @@ -415,12 +413,12 @@ async def task_runner(): "status": "success" if self._task_state.status == TaskStatus.COMPLETED else "error", "content": [ { - "text": f"✅ Task: '{instruction}' - {self._task_state.status.value}\n" - f"🤖 Robot: {self.tool_name_str} ({self.robot})\n" - f"🧠 Policy: {policy_provider} on {policy_host}:{policy_port}\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps: {self._task_state.step_count}" - + (f"\n❌ Error: {self._task_state.error_message}" if self._task_state.error_message else "") + "text": f"Task: '{instruction}' - 
{self._task_state.status.value}\n" + f"Robot: {self.tool_name_str} ({self.robot})\n" + f"Policy: {policy_provider} on {policy_host}:{policy_port}\n" + f"Duration: {self._task_state.duration:.1f}s\n" + f"Steps: {self._task_state.step_count}" + + (f"\nError: {self._task_state.error_message}" if self._task_state.error_message else "") } ], } @@ -439,7 +437,7 @@ def start_task( if self._task_state.status == TaskStatus.RUNNING: return { "status": "error", - "content": [{"text": f"❌ Task already running: {self._task_state.instruction}"}], + "content": [{"text": f"Task already running: {self._task_state.instruction}"}], } # Start task in background @@ -451,10 +449,10 @@ def start_task( "status": "success", "content": [ { - "text": f"🚀 Task started: '{instruction}'\n" - f"🤖 Robot: {self.tool_name_str}\n" - f"💡 Use action='status' to check progress\n" - f"💡 Use action='stop' to interrupt" + "text": f"Task started: '{instruction}'\n" + f"Robot: {self.tool_name_str}\n" + f"Use action='status' to check progress\n" + f"Use action='stop' to interrupt" } ], } @@ -466,20 +464,20 @@ def get_task_status(self) -> dict[str, Any]: if self._task_state.status == TaskStatus.RUNNING: self._task_state.duration = time.time() - self._task_state.start_time - status_text = f"📊 Robot Status: {self._task_state.status.value.upper()}\n" + status_text = f"Robot Status: {self._task_state.status.value.upper()}\n" if self._task_state.instruction: - status_text += f"🎯 Task: {self._task_state.instruction}\n" + status_text += f"Task: {self._task_state.instruction}\n" if self._task_state.status == TaskStatus.RUNNING: - status_text += f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - status_text += f"🔄 Steps: {self._task_state.step_count}\n" + status_text += f"Duration: {self._task_state.duration:.1f}s\n" + status_text += f"Steps: {self._task_state.step_count}\n" elif self._task_state.status in [TaskStatus.COMPLETED, TaskStatus.STOPPED, TaskStatus.ERROR]: - status_text += f"⏱️ Total Duration: 
{self._task_state.duration:.1f}s\n" - status_text += f"🎯 Total Steps: {self._task_state.step_count}\n" + status_text += f"Total Duration: {self._task_state.duration:.1f}s\n" + status_text += f"Total Steps: {self._task_state.step_count}\n" if self._task_state.error_message: - status_text += f"❌ Error: {self._task_state.error_message}\n" + status_text += f"Error: {self._task_state.error_message}\n" return { "status": "success", @@ -502,15 +500,15 @@ def stop_task(self) -> dict[str, Any]: if self._task_state.task_future: self._task_state.task_future.cancel() - logger.info(f"🛑 Task stopped: {self._task_state.instruction}") + logger.info(f"Task stopped: {self._task_state.instruction}") return { "status": "success", "content": [ { - "text": f"🛑 Task stopped: '{self._task_state.instruction}'\n" - f"⏱️ Duration: {self._task_state.duration:.1f}s\n" - f"🎯 Steps completed: {self._task_state.step_count}" + "text": f"Task stopped: '{self._task_state.instruction}'\n" + f"Duration: {self._task_state.duration:.1f}s\n" + f"Steps completed: {self._task_state.step_count}" } ], } @@ -601,7 +599,7 @@ async def stream( tool_use_id, { "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for execute action"}], + "content": [{"text": "Instruction and policy_port are required for execute action"}], }, ) ) @@ -625,7 +623,7 @@ async def stream( tool_use_id, { "status": "error", - "content": [{"text": "❌ instruction and policy_port are required for start action"}], + "content": [{"text": "Instruction and policy_port are required for start action"}], }, ) ) @@ -652,20 +650,20 @@ async def stream( { "status": "error", "content": [ - {"text": f"❌ Unknown action: {action}. Valid actions: execute, start, status, stop"} + {"text": f"Unknown action: {action}. 
Valid actions: execute, start, status, stop"} ], }, ) ) except Exception as e: - logger.error(f"❌ {self.tool_name_str} error: {e}") + logger.error(f"{self.tool_name_str} error: {e}") yield ToolResultEvent( self._make_tool_result( tool_use_id, { "status": "error", - "content": [{"text": f"❌ {self.tool_name_str} error: {str(e)}"}], + "content": [{"text": f"{self.tool_name_str} error: {str(e)}"}], }, ) ) @@ -686,7 +684,7 @@ def cleanup(self): logger.info(f"🧹 {self.tool_name_str} cleanup completed") except Exception as e: - logger.error(f"❌ Cleanup error for {self.tool_name_str}: {e}") + logger.error(f"Cleanup error for {self.tool_name_str}: {e}") def __del__(self): """Destructor to ensure cleanup.""" @@ -730,7 +728,7 @@ async def get_status(self) -> dict[str, Any]: return status_data except Exception as e: - logger.error(f"❌ Error getting status for {self.tool_name_str}: {e}") + logger.error(f"Error getting status for {self.tool_name_str}: {e}") return { "robot_name": self.tool_name_str, "error": str(e), @@ -752,7 +750,7 @@ async def stop(self): # Cleanup resources self.cleanup() - logger.info(f"🛑 {self.tool_name_str} stopped and disconnected") + logger.info(f"{self.tool_name_str} stopped and disconnected") except Exception as e: - logger.error(f"❌ Error stopping robot: {e}") + logger.error(f"Error stopping robot: {e}") diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py index d9674a9..c272008 100644 --- a/strands_robots/simulation/__init__.py +++ b/strands_robots/simulation/__init__.py @@ -1,15 +1,25 @@ -"""Strands Robots Simulation — multi-backend simulation framework. +"""Strands Robots Simulation - multi-backend simulation framework. Architecture:: simulation/ - ├── __init__.py ← this file (re-exports, lazy loading) - ├── base.py ← SimEngine ABC - ├── factory.py ← create_simulation() + backend registration - ├── models.py ← shared dataclasses (SimWorld, SimRobot, ...) 
- └── model_registry.py ← URDF/MJCF resolution (shared across backends) - - # MuJoCo backend added in subsequent PRs. + ├ __init__.py ← this file (re-exports, lazy loading) + ├ base.py ← SimEngine ABC + ├ factory.py ← create_simulation() + backend registration + ├ models.py ← shared dataclasses (SimWorld, SimRobot, ...) + ├ model_registry.py ← URDF/MJCF resolution (shared across backends) + └ mujoco/ ← MuJoCo CPU backend + ├ __init__.py + ├ backend.py ← lazy mujoco import + GL config + ├ spec_builder.py ← MjSpec-based scene builder/mutator + ├ physics.py ← advanced physics (raycasting, jacobians, forces) + ├ scene_ops.py ← live scene mutation via spec.recompile() + ├ rendering.py ← render RGB/depth, observations + ├ policy_runner.py ← run_policy, eval_policy, replay + ├ randomization.py ← domain randomization + ├ recording.py ← LeRobotDataset recording + ├ tool_spec.json ← AgentTool input schema + └ simulation.py ← Simulation (AgentTool orchestrator) Usage:: @@ -39,7 +49,7 @@ import importlib as _importlib from typing import Any -# --- Light imports (no heavy deps — stdlib + dataclasses only) --- +# Light imports (no heavy deps - stdlib + dataclasses only) from strands_robots.simulation.base import SimEngine from strands_robots.simulation.factory import ( create_simulation, @@ -62,10 +72,15 @@ TrajectoryStep, ) -# --- Heavy imports (lazy — loaded when mujoco backend is available) --- -# MuJoCo-specific lazy imports will be added when the mujoco/ subpackage -# is introduced. For now, only the lightweight foundation is available. 
-_LAZY_IMPORTS: dict[str, tuple[str, str]] = {} +# Heavy imports (lazy - need strands SDK + mujoco) +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "Simulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "MuJoCoSimulation": ("strands_robots.simulation.mujoco.simulation", "Simulation"), + "SpecBuilder": ("strands_robots.simulation.mujoco.spec_builder", "SpecBuilder"), + "_configure_gl_backend": ("strands_robots.simulation.mujoco.backend", "_configure_gl_backend"), + "_ensure_mujoco": ("strands_robots.simulation.mujoco.backend", "_ensure_mujoco"), + "_is_headless": ("strands_robots.simulation.mujoco.backend", "_is_headless"), +} __all__ = [ @@ -75,9 +90,9 @@ "create_simulation", "list_backends", "register_backend", - # Default backend alias (available when mujoco backend is installed) - # "Simulation", - # "MuJoCoSimulation", + # Default backend alias + "Simulation", + "MuJoCoSimulation", # Shared dataclasses "SimStatus", "SimRobot", @@ -85,8 +100,8 @@ "SimCamera", "SimWorld", "TrajectoryStep", - # MuJoCo builder (available when mujoco backend is installed) - # "MJCFBuilder", + # MuJoCo scene builder (MjSpec-based, replaces MJCFBuilder) + "SpecBuilder", # Model registry "register_urdf", "resolve_model", @@ -108,4 +123,4 @@ def __getattr__(name: str) -> Any: # NOTE: MuJoCo GL backend configuration lives in the top-level # strands_robots/__init__.py to ensure it runs before any `import mujoco`. -# Do NOT duplicate it here — see PR #86 for the canonical location. +# Do NOT duplicate it here - see PR #86 for the canonical location. diff --git a/strands_robots/simulation/base.py b/strands_robots/simulation/base.py index 7ca2098..71ee141 100644 --- a/strands_robots/simulation/base.py +++ b/strands_robots/simulation/base.py @@ -1,7 +1,7 @@ -"""Simulation ABC — backend-agnostic interface for all simulation engines. +"""Simulation ABC - backend-agnostic interface for all simulation engines. 
Every simulation backend (MuJoCo, Isaac, Newton) implements this interface. -Agent tools and the Robot() factory interact through these methods only — +Agent tools and the Robot() factory interact through these methods only - they never touch backend-specific APIs directly. Usage:: @@ -20,7 +20,10 @@ import logging from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from strands_robots.policies import Policy logger = logging.getLogger(__name__) @@ -29,18 +32,25 @@ class SimEngine(ABC): """Abstract base class for simulation engines. Defines the contract that all backends (MuJoCo, Isaac, Newton) must - implement. This is the *programmatic* API — the AgentTool layer + implement. This is the *programmatic* API - the AgentTool layer wraps it with tool_spec/stream for LLM access. Method categories: - **Required** (``@abstractmethod``): Core simulation loop — world - lifecycle, entity management, observation/action, rendering. Every - physics engine must implement these to be usable. + **Required** (``@abstractmethod``): Core simulation loop - world + lifecycle, entity management, observation/action, rendering, robot + discovery. Every physics engine must implement these to be usable. + + **Provided** (concrete base-class methods): Policy orchestration + (``run_policy`` / ``start_policy`` / ``replay_episode`` / ``eval_policy``) + is implemented once in this ABC as a facade over the abstract primitives. + Backends inherit them for free by implementing the primitives. They + *may* override for backend-specific optimisations (e.g. GPU-batched + policy inference on Isaac). **Optional** (default raises ``NotImplementedError``): Higher-level - features — scene loading, policy running, domain randomization, - contact queries. Backends opt in by overriding only what they support. + features - scene loading, domain randomization, contact queries. + Backends opt in by overriding only what they support. 
Lifecycle:: @@ -61,7 +71,7 @@ class SimEngine(ABC): sim.destroy() """ - # --- World lifecycle --- + # World lifecycle @abstractmethod def create_world( @@ -93,7 +103,7 @@ def get_state(self) -> dict[str, Any]: """Get full simulation state summary.""" ... - # --- Robot management --- + # Robot management @abstractmethod def add_robot( @@ -112,7 +122,26 @@ def remove_robot(self, name: str) -> dict[str, Any]: """Remove a robot from the simulation.""" ... - # --- Object management --- + @abstractmethod + def list_robots(self) -> list[str]: + """Return ordered list of robot names currently in the world. + + Used by the backend-agnostic ``PolicyRunner`` to resolve a + default robot when the caller omits ``robot_name``. + """ + ... + + @abstractmethod + def robot_joint_names(self, robot_name: str) -> list[str]: + """Return ordered joint names for ``robot_name``. + + Used by ``Policy.set_robot_state_keys`` and by + ``PolicyRunner.replay`` to map dataset action-vector indices to + named joints. Order must match the backend's action ordering. + """ + ... + + # Object management @abstractmethod def add_object( @@ -136,16 +165,39 @@ def remove_object(self, name: str) -> dict[str, Any]: """Remove an object from the scene.""" ... - # --- Observation / Action --- + # Observation / Action @abstractmethod - def get_observation(self, robot_name: str | None = None, camera_name: str | None = None) -> dict[str, Any]: - """Get observation from simulation. - - Convenience method that delegates to the underlying Robot - abstraction. Provides a unified interface for agent tools - that interact with simulation without needing to distinguish - between Robot and Sim layers. + def get_observation(self, robot_name: str | None = None, *, skip_images: bool = False) -> dict[str, Any]: + """Get full observation for a robot: joint state + all attached cameras. + + Unified observation consumed by :class:`Policy` and + :class:`~strands_robots.simulation.policy_runner.PolicyRunner`. 
+ Backends MUST return a dict with the following schema; extra keys + are allowed. + + Schema: + - ``"<joint_name>"`` (float): One entry per joint on the robot, + keyed by the *short* joint name (e.g. ``"shoulder_pan"``). + The schema is stable regardless of multi-robot namespacing + at the physics-engine level. + - ``"<camera_name>"`` (np.ndarray): One RGB uint8 frame per + camera associated with the robot, keyed by camera name. + Shape ``(H, W, 3)``. Cameras whose render fails MAY be + omitted; joint state MUST still be returned. + + Single-camera rendering is :meth:`render`'s job, not this method's. + For batched multi-robot observation (future Isaac / Newton), add a + separate ``get_observations(robot_names)`` method - do NOT extend + this one. + + Args: + robot_name: Which robot to observe. If ``None`` and exactly one + robot exists, that robot is used; otherwise returns ``{}``. + + Returns: + Observation dict per schema above. Returns ``{}`` if the world + is not yet created or ``robot_name`` is unknown. """ ... @@ -157,10 +209,14 @@ def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_s abstraction. The simulation engine acts as a facade so agent tools can use ``sim.send_action()`` without knowing about the Robot/Policy layer. + + Backends are responsible for internal thread-safety (e.g. + MuJoCo must acquire an internal lock here). ``PolicyRunner`` + does not manage locks. """ ... - # --- Rendering --- + # Rendering @abstractmethod def render( @@ -174,22 +230,241 @@ def render( """ ... - # --- Optional overrides (have default no-op implementations) --- + # Policy orchestration (concrete facade, not abstract) - def load_scene(self, scene_path: str) -> dict[str, Any]: - """Load a complete scene from file.
Override per backend.""" - raise NotImplementedError("load_scene not implemented by this backend") + def run_policy( + self, + robot_name: str, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + duration: float = 10.0, + control_frequency: float = 50.0, + action_horizon: int = 8, + fast_mode: bool = False, + video: dict[str, Any] | None = None, + policy_object: Policy | None = None, + n_steps: int | None = None, + max_steps: int | None = None, + max_onframe_failures: int | None = None, + ) -> dict[str, Any]: + """Run a policy loop in the simulation (blocking). + + Default implementation delegates to the backend-agnostic + :class:`~strands_robots.simulation.policy_runner.PolicyRunner`. + Backends MAY override for backend-specific optimisations + (e.g. GPU-batched policy inference on Isaac). + + Args: + robot_name: Robot to control. + policy_provider: Name passed to + :func:`strands_robots.policies.create_policy`. + policy_config: Opaque dict of provider-specific kwargs + (``observation_mapping``, ``action_mapping``, ``host``, + ``port``, ``api_token``, ``pretrained_name_or_path``, + ``trust_remote_code``, ``actions_per_step``, + ``use_processor``, ``processor_overrides``, ``device``, + ...). Forwarded verbatim to ``create_policy``. + instruction: Natural-language instruction for the policy. + duration: Wall-clock seconds to run. + control_frequency: Target Hz for policy queries. + action_horizon: Max actions per policy call. + fast_mode: Skip real-time sleep between steps. + video: Optional video-recording config dict. Accepted keys: + ``path`` (str, output MP4 - required to enable recording), + ``fps`` (int, default 30), ``camera`` (str, default backend + default), ``width`` (int, default 640), ``height`` (int, + default 480). See :class:`~strands_robots.simulation.policy_runner.VideoConfig`. 
+ For extension points beyond video (custom telemetry, + dataset recording), backends plug into + ``PolicyRunner.run``'s ``on_frame`` hook via + :meth:`_make_run_policy_hook`. + + Returns: + Standard status dict. + """ + from strands_robots.policies import create_policy + from strands_robots.simulation.policy_runner import PolicyRunner, VideoConfig + + # accept n_steps (or legacy max_steps) as an alternate horizon + # specification. duration = n_steps / control_frequency. If both + # are passed, n_steps wins (primary per DoD). + if n_steps is None and max_steps is not None: + n_steps = int(max_steps) + if n_steps is not None: + if n_steps <= 0: + return { + "status": "error", + "content": [{"text": f"run_policy: n_steps must be > 0, got {n_steps}."}], + } + if control_frequency <= 0: + return { + "status": "error", + "content": [{"text": "run_policy: control_frequency must be > 0 when n_steps is used."}], + } + duration = float(n_steps) / float(control_frequency) + + if robot_name not in self.list_robots(): + return { + "status": "error", + "content": [{"text": f"Robot '{robot_name}' not found."}], + } + + if policy_object is not None: + # Pre-built policy path - skip the expensive create_policy call. + # Caller is responsible for policy.set_robot_state_keys(...) if needed, + # but we set it here defensively so the semantics match the provider path. 
def start_policy(
    self,
    robot_name: str,
    policy_provider: str = "mock",
    policy_config: dict[str, Any] | None = None,
    instruction: str = "",
    duration: float = 10.0,
    control_frequency: float = 50.0,
    action_horizon: int = 8,
    fast_mode: bool = False,
    video: dict[str, Any] | None = None,
    policy_object: Policy | None = None,
    n_steps: int | None = None,
    max_steps: int | None = None,
) -> dict[str, Any]:
    """Begin executing a policy on ``robot_name``.

    This base implementation does not spawn anything: it hands every
    argument straight to :meth:`run_policy` and returns that call's
    result, so it blocks until the run is over. Backends that can
    execute in the background are expected to override it.

    The horizon can be given as ``duration`` or, alternatively, as
    ``n_steps`` (primary) / legacy ``max_steps``; the conversion between
    them is performed by ``run_policy``, not here.

    Returns:
        Whatever ``run_policy`` returns (the standard status dict).
    """
    # Gather the pass-through arguments once so the delegation below is a
    # single call; keys mirror run_policy's keyword names one-to-one.
    forwarded: dict[str, Any] = {
        "policy_provider": policy_provider,
        "policy_config": policy_config,
        "instruction": instruction,
        "duration": duration,
        "control_frequency": control_frequency,
        "action_horizon": action_horizon,
        "fast_mode": fast_mode,
        "video": video,
        "policy_object": policy_object,
        "n_steps": n_steps,
        "max_steps": max_steps,
    }
    return self.run_policy(robot_name, **forwarded)
Add one first."}]} + if robot_name not in robots: + return { + "status": "error", + "content": [{"text": f"Robot '{robot_name}' not found."}], + } + resolved_robot = robot_name + + policy = create_policy(policy_provider, **(policy_config or {})) + policy.set_robot_state_keys(self.robot_joint_names(resolved_robot)) + + return PolicyRunner(self).evaluate( + resolved_robot, + policy, + instruction=instruction, + n_episodes=n_episodes, + max_steps=max_steps, + success_fn=success_fn, + ) + + def _make_run_policy_hook(self, robot_name: str, instruction: str) -> Any: + """Override to return an ``on_frame(step, obs, action)`` callable. + + Used by backends that want to layer in recording / telemetry + without subclassing :class:`PolicyRunner`. Default: no hook. + + Args: + robot_name: Robot being controlled this run. + instruction: Instruction passed to this run. + + Returns: + Callable or ``None``. + """ + return None + + # Optional overrides (have default no-op implementations) + + def load_scene(self, scene_path: str) -> dict[str, Any]: + """Load a complete scene from file. Override per backend.""" + raise NotImplementedError("load_scene not implemented by this backend") def randomize(self, **kwargs: Any) -> dict[str, Any]: """Apply domain randomization. @@ -217,6 +492,6 @@ def __del__(self) -> None: try: self.cleanup() except Exception as e: - # Best-effort cleanup during GC — exceptions can't propagate + # Best-effort cleanup during GC - exceptions can't propagate # from __del__ (CPython ignores them), so log for visibility. logger.warning("Cleanup error during __del__: %s", e) diff --git a/strands_robots/simulation/factory.py b/strands_robots/simulation/factory.py index e7b0a5b..2d6ab03 100644 --- a/strands_robots/simulation/factory.py +++ b/strands_robots/simulation/factory.py @@ -1,4 +1,4 @@ -"""Simulation factory — create_simulation() and runtime backend registration. +"""Simulation factory - create_simulation() and runtime backend registration. 
Mirrors the policy factory pattern: JSON-driven defaults with runtime override capability. Backends are lazy-loaded on first use. @@ -34,9 +34,7 @@ logger = logging.getLogger(__name__) -# ───────────────────────────────────────────────────────────────────── -# Built-in backend registry (lazy loaders — no imports at module load) -# ───────────────────────────────────────────────────────────────────── +# Built-in backend registry (lazy loaders - no imports at module load) _BUILTIN_BACKENDS: dict[str, tuple[str, str]] = { "mujoco": ( @@ -59,9 +57,7 @@ DEFAULT_BACKEND = "mujoco" -# ───────────────────────────────────────────────────────────────────── # Runtime registration (for user-defined backends not in built-ins) -# ───────────────────────────────────────────────────────────────────── _runtime_registry: dict[str, Callable[[], type[SimEngine]]] = {} _runtime_aliases: dict[str, str] = {} diff --git a/strands_robots/simulation/model_registry.py b/strands_robots/simulation/model_registry.py index b7af5e9..6a077ed 100644 --- a/strands_robots/simulation/model_registry.py +++ b/strands_robots/simulation/model_registry.py @@ -1,11 +1,11 @@ -"""Robot model resolution — URDF registry + asset manager. +"""Robot model resolution - URDF registry + asset manager. Bridges the robot registry with actual URDF/MJCF files on disk. Resolution order for :func:`resolve_model`: 1. User-registered URDFs (:func:`register_urdf`) 2. URDF search paths (``STRANDS_ASSETS_DIR``, CWD, etc.) - 3. Asset manager (``robot_descriptions`` — fallback for standard robots) + 3. 
Asset manager (``robot_descriptions`` - fallback for standard robots) """ from __future__ import annotations @@ -18,7 +18,7 @@ logger = logging.getLogger(__name__) # URDF search paths are resolved lazily via :func:`strands_robots.utils.get_search_paths` -# at every lookup — this avoids snapshotting ``Path.cwd()`` and ``STRANDS_ASSETS_DIR`` +# at every lookup - this avoids snapshotting ``Path.cwd()`` and ``STRANDS_ASSETS_DIR`` # at import time, which caused silent wrong-path bugs when tests/notebooks chdir after # import. @@ -39,7 +39,7 @@ except ImportError: _HAS_REGISTRY = False -# Logged lazily on first resolution via _log_configuration_once() — +# Logged lazily on first resolution via _log_configuration_once() - # avoids noisy INFO on every ``import strands_robots``. _CONFIG_LOGGED = False @@ -68,7 +68,7 @@ def resolve_model(name: str, prefer_scene: bool = True) -> str | None: Resolution order (local assets take priority): 1. User-registered URDFs (custom user registrations) 2. URDF search paths (STRANDS_ASSETS_DIR, CWD, etc.) - 3. Asset manager (robot_descriptions — fallback for standard robots) + 3. Asset manager (robot_descriptions - fallback for standard robots) """ _log_configuration_once() # 1+2. Check local/custom paths first (user overrides win) @@ -92,7 +92,7 @@ def resolve_model(name: str, prefer_scene: bool = True) -> str | None: def resolve_urdf(data_config: str) -> str | None: """Resolve a data_config name to a URDF file path. - Also checks the registry's ``legacy_urdf`` field — a backward-compatible + Also checks the registry's ``legacy_urdf`` field - a backward-compatible path for robots that were registered before the MJCF asset system was introduced (e.g. robots originally configured with raw URDF paths). 
""" @@ -137,6 +137,6 @@ def list_available_models() -> str: lines = ["Registered URDFs:"] for name, path in _URDF_REGISTRY.items(): resolved = resolve_urdf(name) - status = "✅" if resolved else "❌" - lines.append(f" {status} {name}: {path}") + status = "[OK]" if resolved else "[MISSING]" + lines.append(f"{status} {name}: {path}") return "\n".join(lines) diff --git a/strands_robots/simulation/models.py b/strands_robots/simulation/models.py index e339d15..55dc2d6 100644 --- a/strands_robots/simulation/models.py +++ b/strands_robots/simulation/models.py @@ -73,7 +73,14 @@ def __post_init__(self) -> None: @dataclass class SimCamera: - """A camera in the simulation.""" + """A camera in the simulation. + + ``origin_robot`` (post-PR #85): when the camera was discovered inside a + robot's URDF during ``add_robot``, this is set to the robot's name so the + scene builder knows NOT to re-add the camera at the top level (it'll be + re-introduced via ``spec.attach(robot_spec)``). For user-added cameras + (via the ``add_camera`` tool action) this stays empty. + """ name: str position: list[float] = field(default_factory=lambda: [1.0, 1.0, 1.0]) @@ -82,6 +89,7 @@ class SimCamera: width: int = 640 height: int = 480 camera_id: int = -1 + origin_robot: str = "" @dataclass @@ -104,14 +112,14 @@ class SimWorld: escape hatches, each with a distinct role so backend implementers know which to use: - * ``_model``: the physics engine's **core model handle** — the single + * ``_model``: the physics engine's **core model handle** - the single compiled/loaded representation of the scene (e.g. ``mujoco.MjModel``, Isaac's ``Scene``, PyBullet's body registry). Every backend has one. - * ``_data``: the physics engine's **core simulation state handle** — + * ``_data``: the physics engine's **core simulation state handle** - the mutable per-step state companion to ``_model`` (e.g. ``mujoco.MjData``, Isaac's ``World``). Every backend has one. 
* ``_backend_state``: a **catch-all dict** for everything else the - backend needs to persist — generated XML, temp dirs, recording + backend needs to persist - generated XML, temp dirs, recording buffers, caches, etc. Prefer this over adding new fields here. All three are typed ``Any``/``dict`` so nothing leaks engine-specific @@ -127,7 +135,7 @@ class SimWorld: status: SimStatus = SimStatus.IDLE sim_time: float = 0.0 step_count: int = 0 - # Engine core handles — set after the backend builds the world. + # Engine core handles - set after the backend builds the world. # Use these for the primary model/state objects only; put everything # else in ``_backend_state`` below. _model: Any = None # Engine-specific model handle (e.g. MjModel, Scene) @@ -138,6 +146,6 @@ class SimWorld: # Prefer this over adding new fields to ``SimWorld``. _backend_state: dict[str, Any] = field(default_factory=dict) # Physics state checkpoints (used by save_state/restore_state in PR #85). - # Kept as a top-level field — requested by @yinsong1986 during review to + # Kept as a top-level field - requested by @yinsong1986 during review to # avoid monkey-patching when ``reset()`` creates a fresh ``SimWorld``. _checkpoints: dict[str, Any] = field(default_factory=dict) diff --git a/strands_robots/simulation/mujoco/__init__.py b/strands_robots/simulation/mujoco/__init__.py new file mode 100644 index 0000000..03c6a03 --- /dev/null +++ b/strands_robots/simulation/mujoco/__init__.py @@ -0,0 +1,32 @@ +"""MuJoCo simulation backend for strands-robots. + +CPU-based physics with offscreen rendering. No GPU required. +Supports URDF/MJCF loading, multi-robot scenes, policy execution, +domain randomization, and LeRobotDataset recording. 
def _first_loadable_library(candidates: tuple[str, ...]) -> str | None:
    """Return the first name in *candidates* that ctypes can load, else ``None``."""
    for soname in candidates:
        try:
            ctypes.cdll.LoadLibrary(soname)
        except OSError:
            continue
        return soname
    return None


def _configure_gl_backend() -> None:  # noqa: C901
    """Auto-configure MuJoCo's OpenGL backend for headless environments.

    MuJoCo reads MUJOCO_GL at import time to select the OpenGL backend:
    - "egl"    -> EGL (GPU-accelerated offscreen)
    - "osmesa" -> OSMesa (CPU software rendering)
    - "glfw"   -> GLFW (default, requires a display server)

    This function MUST be called before ``import mujoco``; setting
    MUJOCO_GL afterwards has no effect.

    Behaviour:
        * A user-provided MUJOCO_GL is never overridden.
        * Nothing is changed when ``_is_headless()`` is false.
        * Otherwise EGL is probed first, then OSMesa; the first backend
          whose shared library can be loaded is written to
          ``os.environ["MUJOCO_GL"]``.
        * If neither loads, a warning with install hints is logged and the
          environment is left untouched.

    Several sonames are tried per backend: the unversioned
    ``libOSMesa.so`` is the name the install hint below associates with
    the ``-dev`` package, so probing only that name would report OSMesa
    as missing on machines that have just the runtime library.
    """
    configured = os.environ.get("MUJOCO_GL")
    if configured:
        logger.debug("MUJOCO_GL already set to '%s', respecting user config", configured)
        return

    if not _is_headless():
        return

    # Ordered by preference: GPU-backed EGL first, CPU OSMesa as fallback.
    probes: tuple[tuple[str, tuple[str, ...], str], ...] = (
        ("egl", ("libEGL.so.1", "libEGL.so"), "GPU-accelerated offscreen"),
        ("osmesa", ("libOSMesa.so", "libOSMesa.so.8", "libOSMesa.so.6"), "CPU software rendering"),
    )
    for backend, sonames, description in probes:
        if _first_loadable_library(sonames) is not None:
            os.environ["MUJOCO_GL"] = backend
            logger.info("Headless environment detected - using MUJOCO_GL=%s (%s)", backend, description)
            return

    logger.warning(
        "Headless environment detected but neither EGL nor OSMesa found. "
        "MuJoCo rendering will likely fail. Install one of:\n"
        "  GPU: apt-get install libegl1-mesa-dev  (or NVIDIA driver provides libEGL)\n"
        "  CPU: apt-get install libosmesa6-dev\n"
        "Then set: export MUJOCO_GL=egl  (or osmesa)"
    )
def _can_render() -> bool:
    """Check if MuJoCo offscreen rendering is available.

    Probes once by creating a minimal ``Renderer``; the outcome is cached
    in the module-level ``_rendering_available`` and returned on every
    later call.

    On a headless system where MUJOCO_GL is unset, the probe is skipped
    and ``False`` is returned: the default GLFW backend would be used, and
    its initialisation aborts the process at the C level before Python can
    catch anything.

    Returns:
        ``True`` if a 1x1 ``Renderer`` could be created and closed,
        ``False`` otherwise (a warning explains why).
    """
    global _rendering_available
    if _rendering_available is not None:
        return _rendering_available

    # Guard: skip the probe entirely when it could only reach GLFW on a
    # machine without a display - rendering is impossible there anyway.
    if _is_headless() and not os.environ.get("MUJOCO_GL"):
        _rendering_available = False
        logger.warning(
            "Headless environment without EGL/OSMesa - rendering disabled. "
            "Physics and joint observations will still work. "
            "Install libegl1-mesa-dev or libosmesa6-dev for camera rendering."
        )
        return False

    mj = _ensure_mujoco()
    try:
        # The probe model must be a valid MJCF document. An empty string
        # fails XML parsing, which would make this probe report
        # "unavailable" on every machine; "<mujoco/>" is the smallest
        # model that compiles.
        model = mj.MjModel.from_xml_string("<mujoco/>")
        renderer = mj.Renderer(model, height=1, width=1)
        renderer.close()
        del renderer
        _rendering_available = True
        logger.info("MuJoCo rendering available")
    except Exception as e:
        # Deliberately broad: any failure while probing means "cannot
        # render", and physics-only use must keep working.
        _rendering_available = False
        logger.warning(
            "MuJoCo rendering unavailable: %s. "
            "Physics/policy will work, but render/camera observations will be skipped. "
            "Install EGL or OSMesa for offscreen rendering.",
            e,
        )
    return _rendering_available
+ + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world``, ``self._lock``, and the host's + ``_require_no_running_policy`` / ``_require_world`` / ``_prune_done_futures`` + helpers. ``TYPE_CHECKING`` stubs below exist so mypy accepts those + lookups; they are a documentary contract, not an enforceable protocol. + + Naming: methods match action names in tool_spec.json for direct dispatch. + """ + + if TYPE_CHECKING: + import threading + + from strands_robots.simulation.models import SimWorld + + _lock: "threading.Lock" + _world: "SimWorld | None" + + def _require_no_running_policy( + self, action_name: str, robot_name: str | None = None + ) -> dict[str, Any] | None: ... + def _require_world(self) -> dict[str, Any] | None: ... + + # State Checkpointing + + def save_state(self, name: str = "default") -> dict[str, Any]: + """Save the full physics state (qpos, qvel, act, time) to a named checkpoint. + + Uses mj_getState with mjSTATE_PHYSICS for complete state capture. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
def load_state(self, name: str = "default") -> dict[str, Any]:
    """Restore physics state from a named checkpoint.

    Args:
        name: Key of a checkpoint previously stored in
            ``self._world._checkpoints``.

    Returns:
        Standard status dict. ``"error"`` when there is no world, when
        ``_require_no_running_policy`` reports a conflict, when the
        checkpoint does not exist, or when the stored state vector does
        not have the size the current model expects.
    """
    world = self._world
    if world is None or world._model is None or world._data is None:
        return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]}
    # load_state during a running policy races worker thread
    if err := self._require_no_running_policy("load_state"):
        return err

    checkpoints = getattr(world, "_checkpoints", {})
    if name not in checkpoints:
        available = list(checkpoints.keys()) if checkpoints else ["none"]
        return {
            "status": "error",
            "content": [{"text": f"Checkpoint '{name}' not found. Available: {available}"}],
        }

    mj = _ensure_mujoco()
    model, data = world._model, world._data
    checkpoint = checkpoints[name]
    state = checkpoint["state"]

    # A checkpoint is only valid for a model with the same physics-state
    # layout. Check before calling into MuJoCo so a stale checkpoint
    # yields a normal error dict instead of an exception from mj_setState.
    expected_size = int(mj.mj_stateSize(model, mj.mjtState.mjSTATE_PHYSICS))
    actual_size = int(np.size(state))
    if actual_size != expected_size:
        return {
            "status": "error",
            "content": [
                {
                    "text": (
                        f"Checkpoint '{name}' does not match the current model "
                        f"(state has {actual_size} floats, model expects {expected_size}). "
                        "Save a new checkpoint for this scene."
                    )
                }
            ],
        }

    with self._lock:
        mj.mj_setState(model, data, state, mj.mjtState.mjSTATE_PHYSICS)
        # Recompute derived quantities so they match the restored state.
        mj.mj_forward(model, data)
        world.sim_time = checkpoint["sim_time"]
        world.step_count = checkpoint["step_count"]

    return {
        "status": "success",
        "content": [{"text": f"📂 State '{name}' restored (t={world.sim_time:.4f}s, step={world.step_count})"}],
    }
Call create_world (or load_scene) first."}]} + # apply_force during a running policy races worker thread + if err := self._require_no_running_policy("apply_force"): + return err + + # must supply at least one non-zero force or torque + if force is None and torque is None: + return { + "status": "error", + "content": [{"text": "apply_force: specify at least one of 'force' or 'torque' (non-zero vector)."}], + } + + # Validate vector lengths before hitting numpy + for _name, _vec in (("force", force), ("torque", torque), ("point", point)): + if _vec is not None: + try: + if len(_vec) != 3: + return { + "status": "error", + "content": [ + {"text": f"apply_force: '{_name}' must be a 3-element vector [x,y,z], got {len(_vec)}"} + ], + } + except TypeError: + return { + "status": "error", + "content": [{"text": f"apply_force: '{_name}' must be a list/tuple of 3 numbers"}], + } + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + + f = np.array(force or [0, 0, 0], dtype=np.float64) + t = np.array(torque or [0, 0, 0], dtype=np.float64) + # Note: explicit [0,0,0] is a valid "clear the latched force" command; we only + # reject the case where the caller forgot both args (handled above). + p = np.array(point, dtype=np.float64) if point else data.xipos[body_id].copy() + + # Zero the buffer first so calls are idempotent (replace, not accumulate). + # NOTE: MuJoCo does NOT reset qfrc_applied in mj_step - the force + # persists on every subsequent step until the next apply_force call. 
def raycast(
    self,
    origin: list[float],
    direction: list[float],
    exclude_body: int = -1,
    include_static: bool = True,
) -> dict[str, Any]:
    """Cast a ray and find the first geom intersection.

    Uses mj_ray for distance sensing / obstacle detection.

    Args:
        origin: [x, y, z] ray start point in world frame.
        direction: [dx, dy, dz] ray direction (auto-normalized).
        exclude_body: Body ID to exclude from intersection (-1 = none).
        include_static: Whether to include static geoms.

    Returns:
        Standard status dict. On success the second content item is a
        JSON payload with ``hit``, ``distance``, ``geom_id``,
        ``geom_name`` and ``hit_point`` (the last four are ``None`` when
        nothing was hit). All input problems - wrong length, non-numeric
        or non-finite components, zero-length direction - are reported as
        ``"error"`` before MuJoCo is called.
    """
    if self._world is None or self._world._model is None or self._world._data is None:
        return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]}

    def _fail(message: str) -> dict[str, Any]:
        return {"status": "error", "content": [{"text": message}]}

    # Validate vector shapes first; everything below must never hand
    # malformed input to mj_ray.
    try:
        if len(origin) != 3:
            return _fail(f"raycast: 'origin' must be 3 elements [x,y,z], got {len(origin)}")
        if len(direction) != 3:
            return _fail(f"raycast: 'direction' must be 3 elements [dx,dy,dz], got {len(direction)}")
    except TypeError:
        return _fail("raycast: 'origin' and 'direction' must be lists of 3 numbers")

    # Conversion can still fail for 3-element inputs that are not numbers
    # (e.g. strings), so report that as an error result too.
    try:
        pnt = np.array(origin, dtype=np.float64)
        vec = np.array(direction, dtype=np.float64)
    except (TypeError, ValueError):
        return _fail("raycast: 'origin' and 'direction' must be lists of 3 numbers")

    # NaN/inf would slip past the zero-length check below (NaN < eps is
    # False) and reach mj_ray; nested sequences give a non-(3,) shape.
    well_formed = pnt.shape == (3,) and vec.shape == (3,)
    if not well_formed or not (np.isfinite(pnt).all() and np.isfinite(vec).all()):
        return _fail("raycast: 'origin' and 'direction' must each contain 3 finite numbers (no NaN/inf).")

    # Normalize direction; reject zero-length vectors outright.
    norm = np.linalg.norm(vec)
    if norm < 1e-10:
        return _fail("raycast: 'direction' vector is zero-length - supply a non-zero direction.")
    vec = vec / norm

    mj = _ensure_mujoco()
    model, data = self._world._model, self._world._data

    geomid = np.array([-1], dtype=np.int32)
    dist = mj.mj_ray(
        model,
        data,
        pnt,
        vec,
        None,  # geom group filter (None = all)
        1 if include_static else 0,
        exclude_body,
        geomid,
    )

    # Plain Python types keep the JSON payload serialisable.
    hit = bool(dist >= 0)
    hit_geom = int(geomid[0])
    geom_name = None
    if hit and hit_geom >= 0:
        geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, hit_geom)

    result = {
        "hit": hit,
        "distance": float(dist) if hit else None,
        "geom_id": hit_geom if hit else None,
        "geom_name": geom_name,
        "hit_point": (pnt + vec * dist).tolist() if hit else None,
    }

    if hit:
        text = f"🎯 Ray hit '{geom_name or hit_geom}' at dist={dist:.4f}m, point={result['hit_point']}"
    else:
        text = "🎯 Ray: no intersection"

    return {"status": "success", "content": [{"text": text}, {"json": result}]}
get_energy(self) -> dict[str, Any]: + """Compute potential and kinetic energy of the system.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_energyPos(model, data) + mj.mj_energyVel(model, data) + + potential = float(data.energy[0]) + kinetic = float(data.energy[1]) + total = potential + kinetic + + return { + "status": "success", + "content": [ + {"text": f"⚡ Energy: potential={potential:.4f}J, kinetic={kinetic:.4f}J, total={total:.4f}J"}, + {"json": {"potential": potential, "kinetic": kinetic, "total": total}}, + ], + } + + # Mass Matrix + + def get_mass_matrix(self) -> dict[str, Any]: + """Compute the full mass (inertia) matrix M(q). + + M is nv×nv where nv is the number of DoFs. + Useful for dynamics analysis, impedance control, etc. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + # data.qM is only valid after a forward pass; running mj_forward + # ensures the mass matrix reflects the current qpos (e.g. right after + # a reset/load_state). + mj.mj_forward(model, data) + nv = model.nv + M = np.zeros((nv, nv)) + if nv > 0: + mj.mj_fullM(model, M, data.qM) + rank = int(np.linalg.matrix_rank(M)) + cond = float(np.linalg.cond(M)) if rank > 0 else float("inf") + else: + # Empty scene (no DOFs yet) - return a well-typed zero payload + # instead of crashing in numpy on the empty matrix. 
def inverse_dynamics(self) -> dict[str, Any]:
    """Compute inverse dynamics: given qacc, what forces are needed?

    Runs mj_inverse to compute ``qfrc_inverse`` - the generalized forces
    that would produce the current accelerations.

    Returns:
        Standard status dict. On success the JSON payload holds:

        * ``qfrc_inverse``: ``{joint_name: float}`` - the force at the
          joint's first DoF (``jnt_dofadr``). Kept for backward
          compatibility; for a joint with several DoFs this is only the
          first component.
        * ``qfrc_inverse_dofs``: ``{joint_name: [float, ...]}`` - one
          entry per DoF that ``dof_jntid`` assigns to the joint, so
          multi-DoF joints are reported completely.

        Joints without a name are omitted from both mappings.
    """
    world = self._world
    if world is None or world._model is None or world._data is None:
        return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]}

    mj = _ensure_mujoco()
    model, data = world._model, world._data

    # mj_inverse writes into ``data``; take the same lock the other
    # state-mutating methods use and copy the result out while holding it.
    with self._lock:
        mj.mj_inverse(model, data)
        qfrc = np.array(data.qfrc_inverse, dtype=np.float64)

    # Number of DoFs owned by each joint, derived from the dof -> joint map.
    dofs_per_joint = np.bincount(np.asarray(model.dof_jntid, dtype=np.int64), minlength=model.njnt)

    # Build named force mappings
    forces: dict[str, float] = {}
    dof_forces: dict[str, list[float]] = {}
    for joint_id in range(model.njnt):
        name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, joint_id)
        if not name:
            continue
        first_dof = int(model.jnt_dofadr[joint_id])
        forces[name] = float(qfrc[first_dof])
        last_dof = first_dof + int(dofs_per_joint[joint_id])
        dof_forces[name] = qfrc[first_dof:last_dof].tolist()

    return {
        "status": "success",
        "content": [
            {"text": f"🔄 Inverse dynamics: {len(forces)} joint forces computed"},
            {"json": {"qfrc_inverse": forces, "qfrc_inverse_dofs": dof_forces}},
        ],
    }
Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + body_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + + # Position and orientation + pos = data.xpos[body_id].tolist() + quat = data.xquat[body_id].tolist() + rotmat = data.xmat[body_id].reshape(3, 3).tolist() + + # Velocity (6D: angular then linear in world frame) + vel = np.zeros(6) + mj.mj_objectVelocity(model, data, mj.mjtObj.mjOBJ_BODY, body_id, vel, 0) + linvel = vel[3:].tolist() + angvel = vel[:3].tolist() + + # Mass and inertia + mass = float(model.body_mass[body_id]) + com = data.xipos[body_id].tolist() + + state = { + "position": pos, + "quaternion": quat, + "rotation_matrix": rotmat, + "linear_velocity": linvel, + "angular_velocity": angvel, + "mass": mass, + "center_of_mass": com, + } + + text = ( + f"🏷️ Body '{body_name}' (id={body_id}):\n" + f" pos: [{pos[0]:.4f}, {pos[1]:.4f}, {pos[2]:.4f}]\n" + f" quat: [{quat[0]:.4f}, {quat[1]:.4f}, {quat[2]:.4f}, {quat[3]:.4f}]\n" + f" linvel: [{linvel[0]:.4f}, {linvel[1]:.4f}, {linvel[2]:.4f}]\n" + f" angvel: [{angvel[0]:.4f}, {angvel[1]:.4f}, {angvel[2]:.4f}]\n" + f" mass: {mass:.4f}kg, com: {com}" + ) + + return {"status": "success", "content": [{"text": text}, {"json": state}]} + + # Direct Joint Control + + def set_joint_positions( + self, + positions: dict[str, float] | list[float] | None = None, + robot_name: str | None = None, + ) -> dict[str, Any]: + """Set joint positions directly (bypassing actuators). + + Writes to qpos and runs mj_forward to update kinematics. + Useful for teleportation, IK solutions, or keyframe setting. + + Accepts EITHER form: + + * dict: {joint_name: value, ...} - explicit per-joint, safest in multi-robot scenes. + * list/tuple: [v0, v1, ...] - ordered positional. 
Must match a single robot's + joint count (when ``robot_name`` is given, that robot's joints; otherwise the + world must contain exactly one robot, or the call errors). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + # mutating qpos under a running policy races mj_step + if err := self._require_no_running_policy("set_joint_positions"): + return err + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if positions is None: + return { + "status": "error", + "content": [{"text": "set_joint_positions: 'positions' is required (list or dict of joint values)."}], + } + + # normalize list input to dict using a deterministic joint ordering + ignored: list[str] = [] + if isinstance(positions, (list, tuple)): + robots = list(self._world.robots.values()) + if robot_name is not None: + robots = [r for r in robots if r.name == robot_name] + if not robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + if len(robots) == 0: + return { + "status": "error", + "content": [ + { + "text": "set_joint_positions: list form requires a robot in the world; pass a dict instead, or add a robot first." + } + ], + } + if len(robots) > 1 and robot_name is None: + return { + "status": "error", + "content": [ + { + "text": f"set_joint_positions: list form is ambiguous with {len(robots)} robots; pass 'robot_name=' or use a dict." 
+ } + ], + } + robot = robots[0] + joint_names = list(getattr(robot, "joint_names", []) or []) + if not joint_names: + # Fall back: enumerate joints that belong to this robot via namespace + ns = getattr(robot, "namespace", "") or "" + joint_names = [] + for jid in range(model.njnt): + jn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, jid) + if jn and (not ns or jn.startswith(ns)): + joint_names.append(jn) + if len(positions) != len(joint_names): + return { + "status": "error", + "content": [ + { + "text": ( + f"set_joint_positions: list length {len(positions)} does not match robot " + f"'{robot.name}' joint count {len(joint_names)}. Use a dict for partial updates." + ) + } + ], + } + positions = dict(zip(joint_names, positions, strict=True)) + elif not isinstance(positions, dict): + return { + "status": "error", + "content": [ + {"text": f"set_joint_positions: 'positions' must be a dict or list, got {type(positions).__name__}"} + ], + } + + set_count = 0 + with self._lock: + for jnt_name, value in positions.items(): + jnt_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_adr = model.jnt_qposadr[jnt_id] + data.qpos[qpos_adr] = float(value) + set_count += 1 + else: + ignored.append(jnt_name) + logger.warning("Joint '%s' not found, skipping", jnt_name) + + mj.mj_forward(model, data) + + msg = f"🎯 Set {set_count}/{len(positions)} joint positions, FK updated" + if ignored: + msg += f" (ignored: {ignored})" + return { + "status": "success", + "content": [{"text": msg}], + } + + def set_joint_velocities( + self, + velocities: dict[str, float] | list[float] | None = None, + robot_name: str | None = None, + ) -> dict[str, Any]: + """Set joint velocities directly. + + Writes to qvel. Useful for initializing dynamics. Accepts dict or list + (see set_joint_positions for list semantics). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("set_joint_velocities"): + return err + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if velocities is None: + return { + "status": "error", + "content": [{"text": "set_joint_velocities: 'velocities' is required (list or dict)."}], + } + + ignored: list[str] = [] + if isinstance(velocities, (list, tuple)): + robots = list(self._world.robots.values()) + if robot_name is not None: + robots = [r for r in robots if r.name == robot_name] + if not robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + if len(robots) == 0: + return { + "status": "error", + "content": [{"text": "set_joint_velocities: list form requires a robot in the world."}], + } + if len(robots) > 1 and robot_name is None: + return { + "status": "error", + "content": [ + { + "text": f"set_joint_velocities: list form is ambiguous with {len(robots)} robots; pass 'robot_name=' or use a dict." + } + ], + } + robot = robots[0] + joint_names = list(getattr(robot, "joint_names", []) or []) + if not joint_names: + ns = getattr(robot, "namespace", "") or "" + joint_names = [] + for jid in range(model.njnt): + jn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, jid) + if jn and (not ns or jn.startswith(ns)): + joint_names.append(jn) + if len(velocities) != len(joint_names): + return { + "status": "error", + "content": [ + { + "text": ( + f"set_joint_velocities: list length {len(velocities)} does not match robot " + f"'{robot.name}' joint count {len(joint_names)}. Use a dict for partial updates." 
+ ) + } + ], + } + velocities = dict(zip(joint_names, velocities, strict=True)) + elif not isinstance(velocities, dict): + return { + "status": "error", + "content": [ + { + "text": f"set_joint_velocities: 'velocities' must be a dict or list, got {type(velocities).__name__}" + } + ], + } + + set_count = 0 + with self._lock: + for jnt_name, value in velocities.items(): + jnt_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + dof_adr = model.jnt_dofadr[jnt_id] + data.qvel[dof_adr] = float(value) + set_count += 1 + else: + ignored.append(jnt_name) + + msg = f"💨 Set {set_count}/{len(velocities)} joint velocities" + if ignored: + msg += f" (ignored: {ignored})" + return { + "status": "success", + "content": [{"text": msg}], + } + + # Sensor Readout + + def get_sensor_data(self, sensor_name: str | None = None) -> dict[str, Any]: + """Read sensor values from the simulation. + + MuJoCo supports: jointpos, jointvel, accelerometer, gyro, force, + torque, touch, rangefinder, framequat, subtreecom, clock, etc. + + Args: + sensor_name: Specific sensor name, or None for all sensors. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + if model.nsensor == 0: + # distinguish "no sensors at all" from "that specific sensor not found" + if sensor_name: + return { + "status": "error", + "content": [{"text": f"Sensor '{sensor_name}' not found. 
Model has no sensors."}], + } + return {"status": "success", "content": [{"text": "📡 No sensors in model."}]} + + mj.mj_forward(model, data) + + sensors = {} + for i in range(model.nsensor): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_SENSOR, i) + if not name: + name = f"sensor_{i}" + + adr = model.sensor_adr[i] + dim = model.sensor_dim[i] + values = data.sensordata[adr : adr + dim].tolist() + + if sensor_name and name != sensor_name: + continue + + sensors[name] = { + "values": values if dim > 1 else values[0], + "dim": int(dim), + "type": int(model.sensor_type[i]), + } + + if sensor_name and sensor_name not in sensors: + return {"status": "error", "content": [{"text": f"Sensor '{sensor_name}' not found."}]} + + lines = [f"📡 Sensors ({len(sensors)}/{model.nsensor}):"] + for name, info in sensors.items(): + lines.append(f"{name}: {info['values']} (dim={info['dim']})") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"json": {"sensors": sensors}}], + } + + # Runtime Model Modification + + def set_body_properties( + self, + body_name: str, + mass: float | None = None, + ) -> dict[str, Any]: + """Modify body properties at runtime (no recompile needed). + + Changes take effect on the next mj_step. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("set_body_properties"): + return err + + # mass must be > 0 (physics invariant) + if mass is not None: + try: + mass = float(mass) + except (TypeError, ValueError): + return { + "status": "error", + "content": [{"text": f"set_body_properties: 'mass' must be a positive number, got {mass!r}"}], + } + if mass <= 0: + return { + "status": "error", + "content": [{"text": f"set_body_properties: 'mass' must be > 0, got {mass}"}], + } + + mj = _ensure_mujoco() + model = self._world._model + body_id = self._resolve_mj_name(mj.mjtObj.mjOBJ_BODY, body_name) + if body_id < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + + changes = [] + with self._lock: + if mass is not None: + old_mass = float(model.body_mass[body_id]) + model.body_mass[body_id] = mass + changes.append(f"mass: {old_mass:.3f} → {mass:.3f}") + + return { + "status": "success", + "content": [{"text": f"🔧 Body '{body_name}': {', '.join(changes)}"}], + } + + def set_geom_properties( + self, + geom_name: str | None = None, + geom_id: int | None = None, + color: list[float] | None = None, + friction: list[float] | None = None, + size: list[float] | None = None, + ) -> dict[str, Any]: + """Modify geom properties at runtime (no recompile needed). + + Changes take effect immediately for rendering (color) or next step (friction, size). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("set_geom_properties"): + return err + + mj = _ensure_mujoco() + model = self._world._model + + gid = geom_id + if geom_name: + gid = self._resolve_mj_name(mj.mjtObj.mjOBJ_GEOM, geom_name) + # our add_object pipeline names geoms as ``{object_name}_geom``. + # Accept the plain object name as a convenience alias. 
+ if (gid is None or gid < 0) and not geom_name.endswith("_geom"): + gid = self._resolve_mj_name(mj.mjtObj.mjOBJ_GEOM, f"{geom_name}_geom") + if gid is None or gid < 0 or gid >= model.ngeom: + return {"status": "error", "content": [{"text": f"Geom '{geom_name or geom_id}' not found."}]} + + label = geom_name or f"geom_{gid}" + changes = [] + + with self._lock: + if color is not None: + model.geom_rgba[gid] = color[:4] if len(color) >= 4 else color[:3] + [1.0] + changes.append(f"color → {model.geom_rgba[gid].tolist()}") + + if friction is not None: + fric = friction[:3] if len(friction) >= 3 else friction + [0.0] * (3 - len(friction)) + model.geom_friction[gid] = fric + changes.append(f"friction → {fric}") + + if size is not None: + n = min(len(size), 3) + model.geom_size[gid, :n] = size[:n] + changes.append(f"size → {model.geom_size[gid].tolist()}") + + return { + "status": "success", + "content": [{"text": f"🔧 Geom '{label}': {', '.join(changes)}"}], + } + + # Contact Force Analysis + + def get_contact_forces(self) -> dict[str, Any]: + """Get detailed contact forces for all active contacts. + + Uses mj_contactForce for each active contact pair. + Returns normal and friction forces. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom1) or f"geom_{c.geom1}" + g2 = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, c.geom2) or f"geom_{c.geom2}" + + # Get contact force (normal + friction in contact frame) + force = np.zeros(6) + mj.mj_contactForce(model, data, i, force) + + contacts.append( + { + "geom1": g1, + "geom2": g2, + "distance": float(c.dist), + "position": c.pos.tolist(), + "normal_force": float(force[0]), + "friction_force": force[1:3].tolist(), + "full_wrench": force.tolist(), + } + ) + + if not contacts: + return {"status": "success", "content": [{"text": "💥 No active contacts."}]} + + lines = [f"💥 {len(contacts)} contacts:"] + for c in contacts[:15]: + lines.append(f"{c['geom1']} ↔ {c['geom2']}: normal={c['normal_force']:.3f}N, dist={c['distance']:.4f}m") + if len(contacts) > 15: + lines.append(f" ... and {len(contacts) - 15} more") + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"json": {"contacts": contacts}}], + } + + # Multi-Ray (batch raycasting) + + def multi_raycast( + self, + origin: list[float], + directions: list[list[float]], + exclude_body: int = -1, + ) -> dict[str, Any]: + """Cast multiple rays from a single origin (e.g., for LIDAR simulation). + + Efficiently casts N rays using individual mj_ray calls. + Returns array of distances and hit geoms. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + # validate origin shape; per-ray zero-direction guard (avoid mj_ray abort) + try: + if len(origin) != 3: + return { + "status": "error", + "content": [{"text": f"multi_raycast: 'origin' must be 3 elements [x,y,z], got {len(origin)}"}], + } + except TypeError: + return {"status": "error", "content": [{"text": "multi_raycast: 'origin' must be a list of 3 numbers"}]} + + pnt = np.array(origin, dtype=np.float64) + results: list[dict[str, Any]] = [] + + for idx, d in enumerate(directions): + try: + if len(d) != 3: + results.append( + { + "distance": None, + "geom_id": None, + "error": f"ray[{idx}]: direction must have 3 elements, got {len(d)}", + } + ) + continue + except TypeError: + results.append( + {"distance": None, "geom_id": None, "error": f"ray[{idx}]: direction must be a list of 3 numbers"} + ) + continue + vec = np.array(d, dtype=np.float64) + norm = np.linalg.norm(vec) + if norm < 1e-10: + results.append({"distance": None, "geom_id": None, "error": f"ray[{idx}]: zero-length direction"}) + continue + vec /= norm + geomid = np.array([-1], dtype=np.int32) + dist = mj.mj_ray(model, data, pnt, vec, None, 1, exclude_body, geomid) + results.append( + { + "distance": float(dist) if dist >= 0 else None, + "geom_id": int(geomid[0]) if dist >= 0 else None, + } + ) + + hit_count = sum(1 for r in results if r["distance"] is not None) + return { + "status": "success", + "content": [ + {"text": f"🎯 Multi-ray: {hit_count}/{len(directions)} hits from {origin}"}, + {"json": {"rays": results}}, + ], + } + + # Forward Kinematics (explicit) + + def forward_kinematics(self, body_name: str | None = None) -> dict[str, Any]: + """Run forward kinematics to update all body positions/orientations. + + Usually called implicitly by mj_step, but useful after manually + setting qpos to see updated Cartesian positions. 
+ + If ``body_name`` is given, the response is filtered to that + single body (and errors cleanly if the body doesn't exist). + Otherwise returns every body as before. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + + mj.mj_kinematics(model, data) + mj.mj_comPos(model, data) + mj.mj_camlight(model, data) + + if body_name is not None: + bid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, body_name) + if bid < 0: + return {"status": "error", "content": [{"text": f"Body '{body_name}' not found."}]} + body_payload = { + "position": data.xpos[bid].tolist(), + "quaternion": data.xquat[bid].tolist(), + } + return { + "status": "success", + "content": [ + {"text": f"🦴 FK for '{body_name}': pos={body_payload['position']}"}, + {"json": {"body": body_name, **body_payload}}, + ], + } + + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + bodies[name] = { + "position": data.xpos[i].tolist(), + "quaternion": data.xquat[i].tolist(), + } + + return { + "status": "success", + "content": [ + {"text": f"🦴 FK computed for {model.nbody} bodies"}, + {"json": {"bodies": bodies}}, + ], + } + + # Total Mass + + def get_total_mass(self) -> dict[str, Any]: + """Get total mass and per-body mass breakdown.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model = self._world._model + + total = float(mj.mj_getTotalmass(model)) + bodies = {} + for i in range(model.nbody): + name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + m = float(model.body_mass[i]) + if m > 0: + bodies[name] = m + + return { + "status": "success", + "content": [ + {"text": f"⚖️ Total mass: {total:.4f}kg ({len(bodies)} bodies with mass)"}, + {"json": {"total_mass": total, "bodies": bodies}}, + ], + } + + # Export Model XML + + def export_xml(self, output_path: str | None = None) -> dict[str, Any]: + """Export the current model to MJCF XML. + + Uses mj_saveLastXML - exports the exact model currently loaded, + including any runtime modifications. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + + if output_path: + mj.mj_saveLastXML(output_path, self._world._model) + return {"status": "success", "content": [{"text": f"📄 Model exported to {output_path}"}]} + else: + # Return XML string via saveLastXML to temp file + import os + import tempfile + + with tempfile.NamedTemporaryFile(suffix=".xml", mode="w", delete=False) as tmp: + tmpfile = tmp.name + mj.mj_saveLastXML(tmpfile, self._world._model) + try: + with open(tmpfile) as f: + xml = f.read() + finally: + os.unlink(tmpfile) + return { + "status": "success", + "content": [ + {"text": f"📄 Model XML ({len(xml)} chars):\n{xml[:2000]}{'...' 
if len(xml) > 2000 else ''}"} + ], + } diff --git a/strands_robots/simulation/mujoco/randomization.py b/strands_robots/simulation/mujoco/randomization.py new file mode 100644 index 0000000..923e071 --- /dev/null +++ b/strands_robots/simulation/mujoco/randomization.py @@ -0,0 +1,137 @@ +"""Domain randomization mixin.""" + +import logging +from typing import TYPE_CHECKING, Any + +import numpy as np + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RandomizationMixin: + """Domain randomization mixed into ``Simulation``. + + Recolors geoms, perturbs lighting, and scales body mass / geom friction + by a random factor inside a user-supplied range. + + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world``, ``self._lock``, and the host's + ``_require_no_running_policy`` / ``_require_world`` helpers. ``TYPE_CHECKING`` + stubs below exist so mypy accepts those lookups; they are a + documentary contract, not an enforceable protocol. + """ + + if TYPE_CHECKING: + import threading + + from strands_robots.simulation.models import SimWorld + + _lock: "threading.Lock" + _world: "SimWorld | None" + + def _require_no_running_policy( + self, action_name: str, robot_name: str | None = None + ) -> dict[str, Any] | None: ... + def _require_world(self) -> dict[str, Any] | None: ... + + def randomize( + self, + randomize_colors: bool = True, + randomize_lighting: bool = True, + randomize_physics: bool = False, + randomize_positions: bool = False, + position_noise: float = 0.02, + color_range: tuple[float, float] = (0.1, 1.0), + friction_range: tuple[float, float] = (0.5, 1.5), + mass_range: tuple[float, float] = (0.5, 2.0), + seed: int | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Apply domain randomization to the scene. + + Each flag is opt-in per-axis. Defaults: + - ``randomize_colors=True`` - geom RGB re-sampled in ``color_range``. 
+ - ``randomize_lighting=True`` - light pos jittered ±0.5m, diffuse resampled. + - ``randomize_physics=False`` - friction/mass left untouched unless asked. + - ``randomize_positions=False`` - object qpos left untouched unless asked. + + "No flags" means "nothing is randomized" - the call is a no-op. This + matches the LLM ergonomics principle: explicit is better than implicit. + Randomization IS destructive (writes to ``model.geom_*`` / ``body_*`` + arrays and to ``data.qpos``); recompile the scene to undo. + + Args: + randomize_colors: Re-sample geom RGB values. + randomize_lighting: Jitter light positions + diffuse colour. + randomize_physics: Scale geom friction and body mass. + randomize_positions: Add uniform noise to dynamic-object xyz. + position_noise: Max ± xyz offset in meters when randomising positions. + color_range: (lo, hi) for uniform RGB sampling. + friction_range: (lo, hi) multiplicative scale on friction[0]. + mass_range: (lo, hi) multiplicative scale on body_mass. + seed: Optional np.random seed for reproducibility. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + # domain randomization mutates model arrays; a running policy racing with it is UB + if err := self._require_no_running_policy("randomize"): + return err + + rng = np.random.default_rng(seed) + mj = _ensure_mujoco() + model = self._world._model + data = self._world._data + changes = [] + + with self._lock: + if randomize_colors: + for i in range(model.ngeom): + geom_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, i) + if geom_name and geom_name != "ground": + model.geom_rgba[i, :3] = rng.uniform(color_range[0], color_range[1], size=3) + changes.append(f"🎨 Colors: {model.ngeom} geoms randomized") + + if randomize_lighting: + for i in range(model.nlight): + model.light_pos[i] += rng.uniform(-0.5, 0.5, size=3) + model.light_diffuse[i] = rng.uniform(0.3, 1.0, size=3) + changes.append(f"💡 Lighting: {model.nlight} lights randomized") + + if randomize_physics: + friction_scales = {} + for i in range(model.ngeom): + gn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, i) or f"geom_{i}" + f = float(rng.uniform(*friction_range)) + model.geom_friction[i, 0] *= f + friction_scales[gn] = f + mass_scales = {} + for i in range(model.nbody): + if model.body_mass[i] > 0: + bn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, i) or f"body_{i}" + s = float(rng.uniform(*mass_range)) + model.body_mass[i] *= s + mass_scales[bn] = s + changes.append( + f"⚙️ Physics: {len(friction_scales)} geoms friction-scaled, {len(mass_scales)} bodies mass-scaled" + ) + changes.append(f" friction_scales={friction_scales}") + changes.append(f" mass_scales={mass_scales}") + + if randomize_positions: + for obj_name, obj in self._world.objects.items(): + if not obj.is_static: + jnt_name = f"{obj_name}_joint" + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + qpos_addr = model.jnt_qposadr[jnt_id] + noise = rng.uniform(-position_noise, position_noise, size=3) + data.qpos[qpos_addr : qpos_addr + 3] += noise + mj.mj_forward(model, data) + 
changes.append(f"📍 Positions: ±{position_noise}m noise on dynamic objects") + + return { + "status": "success", + "content": [{"text": "🎲 Domain Randomization applied:\n" + "\n".join(changes)}], + } diff --git a/strands_robots/simulation/mujoco/recording.py b/strands_robots/simulation/mujoco/recording.py new file mode 100644 index 0000000..0ba6c01 --- /dev/null +++ b/strands_robots/simulation/mujoco/recording.py @@ -0,0 +1,218 @@ +"""Recording mixin - start/stop trajectory recording to LeRobotDataset.""" + +import logging +import shutil +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RecordingMixin: + """Trajectory recording mixed into ``Simulation``. + + Writes per-step observations + actions + instruction to a LeRobotDataset + via ``start_recording`` / ``stop_recording`` and the ``on_frame`` hook + in ``PolicyRunner``. Separately from that, ``start_cameras_recording`` + dumps raw per-camera MP4s. + + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world`` (trajectory buffer + dataset_recorder live in + ``_world._backend_state``). ``TYPE_CHECKING`` stub below exists so mypy + accepts the ``_world`` lookup; it is a documentary contract, not an + enforceable protocol. + """ + + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + + def start_recording( + self, + repo_id: str = "local/sim_recording", + task: str = "", + fps: int = 30, + root: str | None = None, + push_to_hub: bool = False, + vcodec: str = "libsvtav1", + overwrite: bool = False, + ) -> dict[str, Any]: + """Start recording to LeRobotDataset format (parquet + per-camera MP4). + + Requires the ``lerobot`` extra for the dataset schema. 
If you only + need plain MP4 video (no dataset schema, no policy-training metadata), + use :meth:`start_cameras_recording` - it runs under the + ``[sim-mujoco]`` extra alone (imageio-ffmpeg backend). + + Raises: + Friendly error when ``lerobot`` is not installed, directing the + caller to :meth:`start_cameras_recording` or to install the + optional extra. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + _DatasetRecorder: Any = None + _has_lerobot = False + try: + from strands_robots.dataset_recorder import DatasetRecorder as _DatasetRecorder + from strands_robots.dataset_recorder import has_lerobot_dataset as _check_lerobot + + _has_lerobot = _check_lerobot() + except ImportError: + pass + + if not _has_lerobot or _DatasetRecorder is None: + return { + "status": "error", + "content": [ + { + "text": ( + "start_recording produces a LeRobotDataset (parquet + video) and " + "requires the lerobot extra. For plain MP4 video under the " + "[sim-mujoco] extra alone, use start_cameras_recording instead.\n" + "\n" + " - Dataset + policy training data: pip install 'strands-robots[lerobot]'\n" + " - Plain MP4 only: start_cameras_recording(cameras=..., output_dir=...)" + ) + } + ], + } + + self._world._backend_state["recording"] = True + self._world._backend_state["trajectory"] = [] + self._world._backend_state["push_to_hub"] = push_to_hub + + try: + if overwrite: + if root: + dataset_dir = Path(root) + elif "/" not in repo_id or repo_id.startswith("/") or repo_id.startswith("./"): + dataset_dir = Path(repo_id) + else: + dataset_dir = Path.home() / ".cache" / "huggingface" / "lerobot" / repo_id + if dataset_dir.exists() and dataset_dir.is_dir(): + shutil.rmtree(dataset_dir) + logger.info("Removed existing dataset dir: %s", dataset_dir) + + # Collect joint names from every robot. 
When the scene contains + # more than one robot (e.g. multi-agent dual-task recording), prefix + # each joint with the robot's instance name (``alice__shoulder_pan``) + # so the dataset schema has unique joint ids per agent. Single-robot + # scenes keep the clean ``shoulder_pan`` names for backwards compat. + joint_names: list[str] = [] + camera_keys: list[str] = [] + robot_type = "unknown" + multi_robot = len(self._world.robots) > 1 + for rname, robot in self._world.robots.items(): + if multi_robot: + joint_names.extend(f"{rname}__{jn}" for jn in robot.joint_names) + else: + joint_names.extend(robot.joint_names) + robot_type = robot.data_config or rname + + mj = _ensure_mujoco() + for i in range(self._world._model.ncam): + cam_name = mj.mj_id2name(self._world._model, mj.mjtObj.mjOBJ_CAMERA, i) + if not cam_name: + continue + # LeRobot feature names can't contain '/' (reserved for + # nested-feature addressing). When a robot injects a + # namespaced camera (e.g. ``arm0/wrist_cam``), collapse + # the separator to ``__`` for the dataset schema. 
+ safe_name = cam_name.replace("/", "__") + camera_keys.append(safe_name) + + assert _DatasetRecorder is not None # checked above + self._world._backend_state["dataset_recorder"] = _DatasetRecorder.create( + repo_id=repo_id, + fps=fps, + robot_type=robot_type, + joint_names=joint_names, + camera_keys=camera_keys, + task=task, + root=root, + vcodec=vcodec, + ) + return { + "status": "success", + "content": [ + { + "text": ( + f"Recording to LeRobotDataset: {repo_id}\n" + f"{len(joint_names)} joints, {len(camera_keys)} cameras @ {fps}fps\n" + f"Codec: {vcodec} | Task: {task or '(set per policy)'}\n" + f"Run policies to capture frames, then stop_recording to save episode" + ) + } + ], + } + except Exception as e: + self._world._backend_state["recording"] = False + logger.error("Dataset recorder init failed: %s", e) + return {"status": "error", "content": [{"text": f"Dataset init failed: {e}"}]} + + def stop_recording(self, output_path: str | None = None) -> dict[str, Any]: + """Stop recording and save episode to LeRobotDataset. + + idempotent - calling when not recording succeeds with a + 'Was not recording' message so callers can safely call it unconditionally. 
+ """ + if self._world is None or not self._world._backend_state.get("recording", False): + return {"status": "success", "content": [{"text": "Was not recording."}]} + + self._world._backend_state["recording"] = False + recorder = self._world._backend_state.get("dataset_recorder", None) + + if recorder is None: + return {"status": "error", "content": [{"text": "No dataset recorder active."}]} + + recorder.save_episode() + push_result = None + if self._world._backend_state.get("push_to_hub", False): + push_result = recorder.push_to_hub(tags=["strands-robots", "sim"]) + + repo_id = recorder.repo_id + frame_count = recorder.frame_count + episode_count = recorder.episode_count + root = recorder.root + + recorder.finalize() + self._world._backend_state["dataset_recorder"] = None + self._world._backend_state["trajectory"] = [] + + text = ( + f"Episode saved to LeRobotDataset\n" + f"{repo_id} -- {frame_count} frames, {episode_count} episode(s)\n" + f"Local: {root}" + ) + if push_result and push_result.get("status") == "success": + text += "\nPushed to HuggingFace Hub" + + return {"status": "success", "content": [{"text": text}]} + + def get_recording_status(self) -> dict[str, Any]: + """Returns success in every lifecycle state (no world / not + recording / recording) with a distinguishing message so callers can + poll it unconditionally without try/except.""" + if self._world is None: + return { + "status": "success", + "content": [{"text": "⚪ No world - call create_world to start recording."}], + } + + recording = self._world._backend_state.get("recording", False) + steps = len(self._world._backend_state.get("trajectory", [])) + + if recording: + text = f"🔴 Recording: {steps} steps captured" + else: + text = f"⚪ Not recording (last episode: {steps} steps)" + + return { + "status": "success", + "content": [{"text": text}], + } diff --git a/strands_robots/simulation/mujoco/rendering.py b/strands_robots/simulation/mujoco/rendering.py new file mode 100644 index 
0000000..4137275 --- /dev/null +++ b/strands_robots/simulation/mujoco/rendering.py @@ -0,0 +1,741 @@ +"""Rendering mixin - render, render_depth, get_contacts, observation helpers.""" + +import io +import logging +from typing import TYPE_CHECKING, Any + +from strands_robots.simulation.mujoco.backend import _can_render, _ensure_mujoco + +logger = logging.getLogger(__name__) + + +class RenderingMixin: + """Rendering + observation helpers mixed into ``Simulation``. + + Owns ``render``, ``render_depth``, ``render_all``, ``get_contacts``, and + the low-level ``_apply_sim_action`` (MuJoCo ``ctrl[]`` write + mj_step). + + **Coupling** (see simulation.py top-level docstring): mixin reaches + into ``self._world``, ``self._renderer_tls``, ``self._renderer_model``, + ``self.default_width`` / ``self.default_height``, ``self._lock`` and + ``self._viewer_handle``. ``TYPE_CHECKING`` stubs below exist so mypy + accepts those lookups; they are a documentary contract, not an + enforceable protocol. + + Thread-safety note: MuJoCo ``Renderer`` uses thread-local GL contexts + (CGL on macOS, GLX on Linux). A renderer created on thread A cannot be + reused from thread B - we keep one per-thread via ``_renderer_tls``. + """ + + if TYPE_CHECKING: + from strands_robots.simulation.models import SimWorld + + _world: "SimWorld | None" + + _renderer_model: Any + _renderer_tls: Any # threading.local() - per-thread renderer dict + default_width: int + default_height: int + + def _validate_render_dims(self, width: int, height: int) -> dict[str, Any] | None: + """reject non-positive render dims; convert MuJoCo's framebuffer + overflow to a plain-English message that tells the LLM the actual cap. 
+ """ + if not isinstance(width, int) or not isinstance(height, int): + return { + "status": "error", + "content": [ + {"text": f"render: width/height must be int, got {type(width).__name__}/{type(height).__name__}."} + ], + } + if width <= 0 or height <= 0: + return { + "status": "error", + "content": [{"text": f"render: width and height must be > 0, got {width}x{height}."}], + } + if self._world is not None and self._world._model is not None: + max_w = int(getattr(self._world._model.vis.global_, "offwidth", 1280)) + max_h = int(getattr(self._world._model.vis.global_, "offheight", 960)) + if width > max_w or height > max_h: + return { + "status": "error", + "content": [ + { + "text": ( + f"render: requested {width}x{height} exceeds the offscreen " + f"framebuffer cap ({max_w}x{max_h}). Lower width/height or " + f"rebuild the model with a larger ." + ) + } + ], + } + return None + + def _get_renderer(self, width: int, height: int): + """Get a cached MuJoCo renderer, creating one only if needed. + + Returns None if rendering is unavailable (headless without EGL/OSMesa). + Callers must handle None return. + + Thread-safety: renderers are cached per-thread via ``threading.local`` + because ``mujoco.Renderer`` binds a GL context to the thread that + creates it (CGL on macOS, GLX on Linux). Sharing renderers across + threads would cause ``cgl.free()`` segfaults at cleanup time. + """ + if not _can_render(): + return None + mj = _ensure_mujoco() + assert self._world is not None # callers must check + + # Get or create per-thread renderer dict + renderers = getattr(self._renderer_tls, "renderers", None) + if renderers is None: + renderers = {} + self._renderer_tls.renderers = renderers + self._renderer_tls.model = None + + # Invalidate this thread's cache if model changed (e.g. 
after recompile) + if self._renderer_tls.model is not self._world._model: + renderers.clear() + self._renderer_tls.model = self._world._model + # Keep the per-instance marker for compatibility with any remaining + # read paths that checked self._renderer_model. + self._renderer_model = self._world._model + + key = (width, height) + if key not in renderers: + renderers[key] = mj.Renderer(self._world._model, height=height, width=width) + return renderers[key] + + def _get_sim_observation(self, robot_name: str, *, skip_images: bool = False) -> dict[str, Any]: + """Get observation from sim: joint state + cameras (unless skipped). + + Implements :meth:`SimEngine.get_observation`'s schema. + + Multi-robot note: when the injected robot XML was namespaced + (e.g. ``arm0/shoulder_pan`` in MuJoCo to allow multiple same-config + robots), we look up the prefixed MuJoCo name but return the short + name in the observation dict so the policy sees a stable, config-level + schema regardless of how many robots are in the scene. + """ + mj = _ensure_mujoco() + assert self._world is not None # callers must check + model, data = self._world._model, self._world._data + robot = self._world.robots[robot_name] + pfx = robot.namespace or "" + + obs = {} + for jnt_name in robot.joint_names: + # Try namespaced name first (multi-robot), fall back to raw. + lookup = pfx + jnt_name if pfx else jnt_name + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, lookup) + if jnt_id < 0 and pfx: + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + obs[jnt_name] = float(data.qpos[model.jnt_qposadr[jnt_id]]) + + if skip_images: + return obs + + # Render every camera defined on the model plus any python-side cameras. + # Individual camera failures are logged but do not drop joint state. 
+ cameras_to_render = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + for pycam_name in self._world.cameras: + if pycam_name not in cameras_to_render: + cameras_to_render.append(pycam_name) + + for cname in cameras_to_render: + if not cname: + continue + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, cname) + cam_info = self._world.cameras.get(cname) + h = cam_info.height if cam_info else self.default_height + w = cam_info.width if cam_info else self.default_width + try: + renderer = self._get_renderer(w, h) + if renderer is None: + continue + if cam_id >= 0: + renderer.update_scene(data, camera=cam_id) + else: + renderer.update_scene(data) + obs[cname] = renderer.render().copy() + except (RuntimeError, ValueError) as e: + # Individual camera failure shouldn't stop joint state collection. + # Common cause: camera ID invalid after scene recompile. + logger.debug("Camera render failed for %s: %s", cname, e) + + return obs + + def _apply_sim_action(self, robot_name: str, action_dict: dict[str, Any], n_substeps: int = 1) -> None: + """Apply action dict to sim (same interface as robot.send_action). + + Multi-robot note: action keys are *short* names (e.g. ``shoulder_pan``). + We look up the namespaced MuJoCo actuator/joint name for this + specific ``robot_name`` so the same action dict routes to the right + physical actuator when multiple same-config robots exist. 
+ """ + mj = _ensure_mujoco() + assert self._world is not None # callers must check + model, data = self._world._model, self._world._data + robot = self._world.robots.get(robot_name) + pfx = robot.namespace if robot else "" + + def _lookup(obj_type: Any, name: str) -> int: + """Try namespaced lookup first, fall back to raw.""" + if pfx: + i = mj.mj_name2id(model, obj_type, pfx + name) + if i >= 0: + return i + return int(mj.mj_name2id(model, obj_type, name)) + + for key, value in action_dict.items(): + act_id = _lookup(mj.mjtObj.mjOBJ_ACTUATOR, key) + if act_id >= 0: + data.ctrl[act_id] = float(value) + else: + # Fallback: key is a joint name - find the actuator that + # drives this joint via actuator_trnid (joint ID → actuator). + jnt_id = _lookup(mj.mjtObj.mjOBJ_JOINT, key) + if jnt_id >= 0: + for ai in range(model.nu): + if model.actuator_trnid[ai, 0] == jnt_id: + data.ctrl[ai] = float(value) + break + + for _ in range(max(1, n_substeps)): + mj.mj_step(model, data) + + assert self._world is not None + self._world.sim_time = data.time + assert self._world is not None # callers must check + self._world.step_count += n_substeps + + if hasattr(self, "_viewer_handle") and self._viewer_handle is not None: + self._viewer_handle.sync() + + def render( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + """Render a camera view as base64 PNG image.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + # treat `None` as "use default", but `0` / negative values must + # still hit the validator (bool coercion would swallow them silently). 
+ w = self.default_width if width is None else width + h = self.default_height if height is None else height + if err := self._validate_render_dims(w, h): + return err + + try: + renderer = self._get_renderer(w, h) + if renderer is None: + return { + "status": "error", + "content": [ + { + "text": ( + " Rendering unavailable (no OpenGL context). " + "Install EGL or OSMesa for offscreen rendering: " + "apt-get install libosmesa6-dev" + ) + } + ], + } + # strict camera validation - no silent fallback to default. + # Special 'default' / 'free' tokens route to the free camera; any + # other name MUST resolve or we error (prevents the LLM from + # believing it rendered viewpoint X while actually getting free-cam). + if camera_name in (None, "", "default", "free"): + cam_id = -1 + label = "free (default)" + else: + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + if cam_id < 0: + return { + "status": "error", + "content": [ + {"text": f"Camera '{camera_name}' not found. Available: {self._list_camera_names()}"} + ], + } + label = camera_name + + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + + img = renderer.render().copy() + + from PIL import Image + + pil_img = Image.fromarray(img) + buffer = io.BytesIO() + pil_img.save(buffer, format="PNG") + png_bytes = buffer.getvalue() + + # summary stats so render_all can flag empty-looking frames + # without decoding the PNG a second time. 
+ import numpy as _np + + pixel_var = float(_np.var(img)) + pixel_mean = float(_np.mean(img)) + + return { + "status": "success", + "content": [ + {"text": f"📸 {w}x{h} from '{label}' at t={self._world.sim_time:.3f}s"}, + {"image": {"format": "png", "source": {"bytes": png_bytes}}}, + {"json": {"pixel_variance": pixel_var, "pixel_mean": pixel_mean, "camera": label}}, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"Render failed: {e}"}]} + + def render_depth( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + """Render depth map from a camera.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + # see note in render() re: None vs 0/negative. + w = self.default_width if width is None else width + h = self.default_height if height is None else height + if err := self._validate_render_dims(w, h): + return err + + try: + # strict camera validation (same policy as render()) + if camera_name in (None, "", "default", "free"): + cam_id = -1 + label = "free (default)" + else: + cam_id = mj.mj_name2id(self._world._model, mj.mjtObj.mjOBJ_CAMERA, camera_name) + if cam_id < 0: + return { + "status": "error", + "content": [ + {"text": f"Camera '{camera_name}' not found. Available: {self._list_camera_names()}"} + ], + } + label = camera_name + + renderer = self._get_renderer(w, h) + if renderer is None: + return { + "status": "error", + "content": [ + { + "text": ( + " Depth rendering unavailable (no OpenGL context). " + "Install EGL or OSMesa for offscreen rendering." + ) + } + ], + } + if cam_id >= 0: + renderer.update_scene(self._world._data, camera=cam_id) + else: + renderer.update_scene(self._world._data) + # MuJoCo prints a one-time ARB_clip_control warning on macOS + # when depth precision is reduced. 
Capture stderr on the first + # depth render so we can surface the warning in the response + # text (the LLM otherwise never hears about it). + clip_warn = getattr(self, "_depth_warn_text", None) + if clip_warn is None: + import contextlib as _ctx + import io as _io + import sys as _sys + + buf = _io.StringIO() + with _ctx.redirect_stderr(buf): + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + captured = buf.getvalue() + # Also forward to the real stderr so logs don't vanish. + if captured and _sys.__stderr__ is not None: + try: + _sys.__stderr__.write(captured) + except Exception: + pass + if "ARB_clip_control" in captured: + self._depth_warn_text = "⚠️ Depth accuracy limited on this GPU (missing ARB_clip_control)" + else: + self._depth_warn_text = "" + clip_warn = self._depth_warn_text + else: + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + + text = f"📸 Depth {w}x{h} from '{label}'\nMin: {float(depth.min()):.3f}m, Max: {float(depth.max()):.3f}m" + if clip_warn: + text += f"\n{clip_warn}" + return { + "status": "success", + "content": [ + {"text": text}, + {"json": {"depth_min": float(depth.min()), "depth_max": float(depth.max())}}, + ], + } + except Exception as e: + return {"status": "error", "content": [{"text": f"Depth render failed: {e}"}]} + + def _list_camera_names(self) -> list[str]: + """helper to list all camera names (model-defined + SimCamera aliases) + for error messages when an unknown camera_name is requested.""" + import mujoco as _mj + + names: list[str] = [] + if self._world is not None and self._world._model is not None: + for cid in range(self._world._model.ncam): + raw = _mj.mj_id2name(self._world._model, _mj.mjtObj.mjOBJ_CAMERA, cid) + if raw: + names.append(raw) + # Include SimCamera registry keys (may match model names; dedupe) + for k in self._world.cameras.keys() if self._world else (): + if k not in names: + names.append(k) + 
return names + + def get_contacts(self) -> dict[str, Any]: + """Return the list of active geom-geom contacts at the current step. + + We run ``mj_forward`` first so the contact list reflects the + current qpos/qvel even immediately after ``reset`` or ``add_robot`` + (without this, stale contacts from the previous step / uninitialised + memory can appear as phantom penetrations at t=0). + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + mj = _ensure_mujoco() + model, data = self._world._model, self._world._data + # refresh contact list without advancing time. + mj.mj_forward(model, data) + + def _resolve_geom(gid: int) -> str: + """Prefer the geom name; fall back to its parent body name; then id.""" + gn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_GEOM, gid) + if gn: + return gn + # Walk to the parent body name. + try: + bid = int(model.geom_bodyid[gid]) + bn = mj.mj_id2name(model, mj.mjtObj.mjOBJ_BODY, bid) + if bn: + return f"{bn}/geom_{gid}" + except (IndexError, AttributeError): + pass + return f"geom_{gid}" + + contacts = [] + for i in range(data.ncon): + c = data.contact[i] + g1 = _resolve_geom(c.geom1) + g2 = _resolve_geom(c.geom2) + contacts.append({"geom1": g1, "geom2": g2, "dist": float(c.dist), "pos": c.pos.tolist()}) + + text = f"💥 {len(contacts)} contacts" if contacts else "No contacts." + if contacts: + for c in contacts[:10]: + text += f"\n • {c['geom1']} ↔ {c['geom2']} (d={c['dist']:.4f})" + + return { + "status": "success", + "content": [{"text": text}, {"json": {"contacts": contacts}}], + } + + # Multi-camera capture - Session recording for simulation + + # + # Design: + # - render_all(cameras=None, width=, height=) - single-shot snapshot + # of every camera at current sim_time. One PNG per camera. + # - start_cameras_recording(...) - daemon thread, one imageio writer + # per camera, appends frames at fps. 
+ # - stop_cameras_recording() - flushes writers, returns paths + sizes. + # - get_cameras_recording_status() - frame counts, elapsed, per-cam. + # + # Thread safety: _get_renderer is thread-local (threading.local), so the + # background thread creates its own GL context. No shared state with + # main dispatch thread. + + def _active_camera_list(self, cameras): + """Resolve cameras=None to every camera currently in the world.""" + if self._world is None or self._world._model is None: + return [] + mj = _ensure_mujoco() + model = self._world._model + from_model = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + from_model = [c for c in from_model if c] + py_side = list(self._world.cameras.keys()) if self._world else [] + all_cams = list(dict.fromkeys(from_model + py_side)) + if cameras is None: + return all_cams + missing = [c for c in cameras if c not in all_cams] + if missing: + logger.warning("Unknown camera(s) requested for capture: %s", missing) + return [c for c in cameras if c in all_cams] + + def render_all(self, cameras=None, width=None, height=None): + """Render every (or a subset of) camera in one call. + + Counterpart to ``render()`` for multi-view workflows - e.g. stereo, + overhead + wrist, or all cameras in a 4-view grid. Each camera ships + as its own ``{"image": {...}}`` block in the response. + + Args: + cameras: list of camera names; None = every camera. + width: per-camera width (defaults to camera's configured width). + height: per-camera height (same). + + Returns: + ``{"status", "content": [{"text": summary}, + {"text": "📸 cam1"}, {"image": {...}}, + {"text": "📸 cam2"}, {"image": {...}}, ...]}`` + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + names = self._active_camera_list(cameras) + if not names: + return {"status": "error", "content": [{"text": "No cameras in scene."}]} + content = [] + ok, failed = 0, 0 + low_var_warnings: list[str] = [] + for cam_name in names: + r = self.render(camera_name=cam_name, width=width, height=height) + if r.get("status") == "success": + ok += 1 + img_block = None + stats = None + for block in r.get("content", []): + if isinstance(block, dict): + if "image" in block and img_block is None: + img_block = block + if "json" in block and stats is None: + stats = block["json"] + if img_block is not None: + label = f"📸 {cam_name}" + # flag near-uniform frames (all black / all clear). + if stats and float(stats.get("pixel_variance", 99)) < 1.0: + warn = f"⚠️ camera '{cam_name}': image appears empty (variance < 1)" + label = f"{label} {warn}" + low_var_warnings.append(warn) + content.append({"text": label}) + content.append(img_block) + else: + failed += 1 + err = r.get("content", [{}])[0].get("text", "?") + content.append({"text": f"{cam_name}: {err}"}) + warn_suffix = f", {len(low_var_warnings)} low-variance" if low_var_warnings else "" + summary = ( + f"📸 Multi-camera snapshot at t={self._world.sim_time:.3f}s: " + f"{ok} ok, {failed} failed, {len(names)} requested{warn_suffix}" + ) + return { + "status": "success" if ok else "error", + "content": [{"text": summary}, *content], + } + + def start_cameras_recording( + self, + cameras=None, + output_dir=None, + fps=30, + width=None, + height=None, + name=None, + max_frames_per_camera=3000, + ): + """Start background capture of one ndarray buffer per camera. + + Strategy: the background thread collects raw RGB frames in memory + (one list per camera). ``stop_cameras_recording`` then flushes each + list to an MP4 on the main thread. This avoids a long-lived ffmpeg + subprocess pipe that would break under concurrent imageio writes + + policy-loop timing jitter. 
+ + Memory cost: H*W*3 bytes * fps * duration * n_cams. For a 2s / 4-cam / + 320x240 / 15fps rollout: ~27 MB. Bounded by ``max_frames_per_camera``. + + Args: + cameras: list of camera names; None = every camera. + output_dir: where to write ``{tag}__{cam}.mp4``. + fps: capture rate. + width/height: per-frame size. + name: filename tag (auto if None). + max_frames_per_camera: safety cap on in-memory buffers. + """ + import os as _os + import threading as _threading + import time as _time + import uuid as _uuid + + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + + if getattr(self, "_cams_rec_state", None) and self._cams_rec_state.get("running"): + cur = self._cams_rec_state["name"] + return { + "status": "error", + "content": [{"text": f"Already recording '{cur}'. Call stop_cameras_recording() first."}], + } + + names = self._active_camera_list(cameras) + if not names: + return {"status": "error", "content": [{"text": "No cameras to record."}]} + + out_dir = _os.path.abspath(output_dir or "/tmp/strands_robots/recordings") + _os.makedirs(out_dir, exist_ok=True) + tag = name or f"rec_{_uuid.uuid4().hex[:8]}" + + buffers = {cam: [] for cam in names} + paths = {cam: _os.path.join(out_dir, f"{tag}__{cam}.mp4") for cam in names} + + state = { + "running": True, + "name": tag, + "cameras": names, + "fps": fps, + "width": width, + "height": height, + "buffers": buffers, + "paths": paths, + "errors": dict.fromkeys(names, 0), + "output_dir": out_dir, + "started_at": _time.time(), + "thread": None, + "max_frames": max_frames_per_camera, + } + + def _loop(): + from strands_robots.simulation.policy_runner import _extract_frame_ndarray + + interval = 1.0 / fps + while state["running"]: + t0 = _time.time() + for cam in names: + if not state["running"]: + break + if len(state["buffers"][cam]) >= state["max_frames"]: + continue + try: + r = 
self.render(camera_name=cam, width=width, height=height) + arr = _extract_frame_ndarray(r) + if arr is not None: + state["buffers"][cam].append(arr) + else: + state["errors"][cam] += 1 + except Exception as e: + state["errors"][cam] += 1 + logger.debug("camera recorder (%s) error: %s", cam, e) + lag = _time.time() - t0 + if lag < interval: + _time.sleep(interval - lag) + + state["thread"] = _threading.Thread(target=_loop, daemon=True) + state["thread"].start() + self._cams_rec_state = state + + msg = ( + f"🎬 Recording {len(names)} camera(s) @ {fps} FPS → {out_dir}\n" + f" tag: {tag}\n" + f" cameras: {', '.join(names)}" + ) + return {"status": "success", "content": [{"text": msg}]} + + def stop_cameras_recording(self): + """Stop capture, flush buffers to MP4 on the MAIN thread. + + Runs ``imageio.get_writer``/``append_data``/``close`` here instead of + the recording thread so the ffmpeg pipe doesn't race with policy + timing jitter. Returns per-camera frame counts and paths. + """ + import os as _os + import time as _time + + state = getattr(self, "_cams_rec_state", None) + if not state or not state.get("running"): + # idempotent - 'already stopped' is a success, not an error. + return {"status": "success", "content": [{"text": "Was not recording cameras."}]} + + state["running"] = False + thread = state.get("thread") + if thread is not None: + thread.join(timeout=5.0) + + try: + import imageio.v2 as imageio + except ImportError: + return { + "status": "error", + "content": [{"text": "imageio not installed. 
pip install imageio imageio-ffmpeg"}], + } + + elapsed = _time.time() - state["started_at"] + lines = [ + f"🎬 Stopped '{state['name']}' after {elapsed:.1f}s", + f" output_dir: {state['output_dir']}", + ] + artifacts = [] + for cam in state["cameras"]: + frames_buffer = state["buffers"][cam] + path = state["paths"][cam] + errors = state["errors"][cam] + frames_written = 0 + size_kb = 0.0 + if frames_buffer: + writer = imageio.get_writer(path, fps=state["fps"], quality=8, macro_block_size=1) + try: + for arr in frames_buffer: + writer.append_data(arr) + frames_written += 1 + finally: + writer.close() + if _os.path.exists(path): + size_kb = _os.path.getsize(path) / 1024 + lines.append( + f" 📹 {cam:20s} {frames_written:>5d} frames {size_kb:>7.1f} KB " + f"({errors} errors) → {_os.path.basename(path)}" + ) + artifacts.append( + { + "camera": cam, + "path": path, + "frames": frames_written, + "errors": errors, + "size_kb": size_kb, + } + ) + + name = state["name"] + self._cams_rec_state = None + + return { + "status": "success", + "content": [ + {"text": "\n".join(lines)}, + {"json": {"recording": name, "artifacts": artifacts}}, + ], + } + + def get_cameras_recording_status(self): + """Cheap introspection of an ongoing multi-camera recording.""" + import time as _time + + state = getattr(self, "_cams_rec_state", None) + if not state or not state.get("running"): + return {"status": "success", "content": [{"text": "⚪ No active camera recording."}]} + + elapsed = _time.time() - state["started_at"] + lines = [f"🟢 Recording '{state['name']}' for {elapsed:.1f}s @ {state['fps']} FPS"] + for cam in state["cameras"]: + frames = len(state["buffers"][cam]) + lines.append(f" 📹 {cam:20s} {frames:>5d} frames ({state['errors'][cam]} errors)") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} diff --git a/strands_robots/simulation/mujoco/scene_ops.py b/strands_robots/simulation/mujoco/scene_ops.py new file mode 100644 index 0000000..c80336c --- /dev/null +++ 
b/strands_robots/simulation/mujoco/scene_ops.py @@ -0,0 +1,307 @@ +"""Scene mutation via the MuJoCo ``MjSpec`` AST. + +This module used to contain ~980 lines of XML-round-trip machinery (tmpdir + +``mj_saveLastXML`` + ElementTree parse + name-mangling + regex path patching). +All of that is replaced by ``spec.recompile(model, data)`` which: + +* preserves joint state on unchanged joints automatically, +* initializes new joints to body ``pos``/``quat`` (removing the need to + delete keyframes on freejoint insertion), +* namespaces robot bodies/joints/geoms/actuators/sensors via ``spec.attach()`` + without us walking the tree manually. + +Public API: + +* :func:`inject_robot_into_scene` - ``spec.attach(robot_spec, prefix=...)``. +* :func:`inject_object_into_scene` - ``SpecBuilder.add_object(spec, obj)`` + recompile. +* :func:`inject_camera_into_scene` - ``SpecBuilder.add_camera(spec, cam)`` + recompile. +* :func:`eject_body_from_scene` - ``SpecBuilder.remove_body(spec, name)`` + recompile. +* :func:`eject_robot_from_scene` - walk the spec, delete everything namespaced + under ``{robot_name}/``, then recompile. + +Every function takes a ``SimWorld`` whose ``_backend_state["spec"]`` holds the +live ``MjSpec``. They return ``True`` on success, ``False`` on failure (matching +the legacy API) so call sites in ``simulation.py`` don't need to change. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco +from strands_robots.simulation.mujoco.spec_builder import SpecBuilder + +logger = logging.getLogger(__name__) + + +def _get_spec(world: SimWorld) -> Any | None: + """Fetch the live MjSpec from ``world._backend_state``. + + Callers MUST have run ``_compile_world`` at least once before any scene + mutation - without a spec we can't recompile. 
Returns ``None`` if missing + so callers can return a clean error dict rather than crashing mid-op. + """ + return world._backend_state.get("spec") + + +def _recompile_preserving_state(world: SimWorld, spec: Any) -> bool: + """Recompile ``spec`` in place, replacing ``world._model`` and ``_data``. + + Uses ``spec.recompile(model, data)`` which auto-preserves qpos/qvel for + existing joints and initializes new joints to their body's pos/quat. No + manual state-copy loop is required. + + Also re-discovers per-robot joint and actuator IDs (they may have shifted + as new bodies were inserted earlier in the body tree). Returns True on + success, False on compile failure (logged). + """ + mj = _ensure_mujoco() + try: + new_model, new_data = spec.recompile(world._model, world._data) + except (ValueError, RuntimeError) as e: + logger.error("spec.recompile failed: %s", e) + return False + + world._model = new_model + world._data = new_data + + # Keep the cached XML in sync with the spec for legacy readers (e.g. + # load_scene + add_robot round-trip). + try: + world._backend_state["xml"] = spec.to_xml() + except Exception as xml_err: + logger.debug("spec.to_xml() failed: %s", xml_err) + + # Re-discover per-robot IDs. Names inside MuJoCo are namespaced under + # robot.namespace (e.g. "arm1/shoulder_pan") when robots were attached + # via SpecBuilder.attach_robot; fall back to the raw name otherwise. 
+ for robot in world.robots.values(): + pfx = robot.namespace or "" + robot.joint_ids = [] + robot.actuator_ids = [] + for jnt_name in robot.joint_names: + jid = -1 + if pfx: + jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, pfx + jnt_name) + if jid < 0: + jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jid >= 0: + robot.joint_ids.append(jid) + for i in range(new_model.nu): + jnt_id = new_model.actuator_trnid[i, 0] + if jnt_id in robot.joint_ids: + robot.actuator_ids.append(i) + # Single-robot fallback: if no actuators matched by joint, assume + # all actuators belong to this robot. Matches the legacy behaviour. + if not robot.actuator_ids and len(world.robots) == 1: + robot.actuator_ids = list(range(new_model.nu)) + + return True + + +# ============================================================================= +# Inject +# ============================================================================= + + +def inject_robot_into_scene( + world: SimWorld, + robot: SimRobot, + robot_xml_path: str, +) -> bool: + """Attach a robot to the scene via ``spec.attach(other, prefix=..., frame=...)``. + + MuJoCo handles name prefixing (bodies, joints, geoms, actuators, sensors, + sites), asset deduplication (meshes, textures, materials), and default- + class namespacing. No manual tree-walking required. + + Registers the robot's source joint names on ``robot.joint_names`` so + downstream observation/policy code can resolve them via + ``{robot.namespace}{joint_name}``. 
+ """ + spec = _get_spec(world) + if spec is None or world._model is None: + logger.error("inject_robot: no spec or model in world") + return False + + try: + joint_names = SpecBuilder.attach_robot(spec, robot, robot_xml_path) + robot.joint_names = joint_names + except (ValueError, RuntimeError, OSError) as e: + logger.error("Robot attach failed for '%s': %s", robot.name, e) + return False + + return _recompile_preserving_state(world, spec) + + +def inject_object_into_scene(world: SimWorld, obj: SimObject) -> bool: + """Add a ``SimObject`` to the scene and recompile in place.""" + spec = _get_spec(world) + if spec is None or world._model is None: + logger.error("inject_object: no spec or model in world") + return False + + try: + SpecBuilder.add_object(spec, obj) + except (ValueError, RuntimeError) as e: + logger.error("Object add failed for '%s': %s", obj.name, e) + return False + + return _recompile_preserving_state(world, spec) + + +def inject_camera_into_scene(world: SimWorld, cam: SimCamera) -> bool: + """Add a camera to the scene and recompile in place.""" + spec = _get_spec(world) + if spec is None or world._model is None: + logger.error("inject_camera: no spec or model in world") + return False + + try: + SpecBuilder.add_camera(spec, cam) + except (ValueError, RuntimeError) as e: + logger.error("Camera add failed for '%s': %s", cam.name, e) + return False + + return _recompile_preserving_state(world, spec) + + +# ============================================================================= +# Eject +# ============================================================================= + + +def eject_body_from_scene(world: SimWorld, body_name: str) -> bool: + """Remove a body (by short name) and recompile.""" + spec = _get_spec(world) + if spec is None or world._model is None: + logger.error("eject_body: no spec or model in world") + return False + + if not SpecBuilder.remove_body(spec, body_name): + logger.warning("Body '%s' not found in spec - nothing ejected", 
body_name) + # Matching legacy behaviour: return True so scene state stays consistent + # (caller has already popped the Python-side dict entry). + return True + + return _recompile_preserving_state(world, spec) + + +def eject_robot_from_scene(world: SimWorld, robot_name: str) -> bool: + """Remove every spec element namespaced under ``{robot_name}/``. + + Implementation note: deleting a body that was added via ``spec.attach()`` + triggers a known MuJoCo 3.8 segfault at interpreter shutdown (the + attached child spec's memory gets freed twice). To sidestep that bug + we REBUILD the scene spec from scratch using the post-remove + ``world.robots`` / ``world.objects`` / ``world.cameras`` state, then + re-attach the remaining robots. Joint state is not preserved across this + path - callers that care should call ``reset`` or save/restore state + around remove_robot. In the common case (agent removes a robot to clear + the scene), this is the expected behaviour anyway. + """ + spec = _get_spec(world) + if spec is None or world._model is None: + logger.error("eject_robot: no spec or model in world") + return False + + mj = _ensure_mujoco() + + # Preserve the current qpos for bodies that are NOT being removed. + # We rebuild from world state and then re-attach remaining robots, so + # object freejoints start at their body pos (matching fresh add_object + # semantics); robot joints start at qpos=0 (same as fresh add_robot). + + # First drop cameras that originated from the robot being ejected. + # They're in world.cameras with origin_robot == robot_name. Without this, + # SpecBuilder.build would skip them (via origin_robot), but stale entries + # would linger in the registry and confuse observation code. + stale_cam_names = [cname for cname, cam in world.cameras.items() if getattr(cam, "origin_robot", "") == robot_name] + for cname in stale_cam_names: + del world.cameras[cname] + + # Step 1: rebuild the base spec from world (objects + cameras + + # lights + ground). 
+ new_spec = SpecBuilder.build(world) + + # Step 2: re-attach every remaining robot (the one being ejected is + # already popped from ``world.robots`` by the caller). + for robot in world.robots.values(): + # Re-discover joint names via the attach - they're stable per URDF. + joint_names = SpecBuilder.attach_robot(new_spec, robot, robot.urdf_path) + robot.joint_names = joint_names + + # Step 3: compile fresh and install. No spec.recompile(model, data) + # here - recompile implicitly preserves qpos state which doesn't + # make sense across a scene rebuild, and forcing a fresh compile + # avoids the attach/delete bug. + try: + new_model = new_spec.compile() + new_data = mj.MjData(new_model) + except (ValueError, RuntimeError) as e: + logger.error("eject_robot: fresh compile failed: %s", e) + return False + + world._model = new_model + world._data = new_data + world._backend_state["spec"] = new_spec + try: + world._backend_state["xml"] = new_spec.to_xml() + except Exception as xml_err: + logger.debug("spec.to_xml() failed: %s", xml_err) + + # Re-discover joint/actuator IDs for remaining robots. 
+ for robot in world.robots.values(): + pfx = robot.namespace or "" + robot.joint_ids = [] + robot.actuator_ids = [] + for jnt_name in robot.joint_names: + jid = -1 + if pfx: + jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, pfx + jnt_name) + if jid < 0: + jid = mj.mj_name2id(new_model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jid >= 0: + robot.joint_ids.append(jid) + for i in range(new_model.nu): + jnt_id = new_model.actuator_trnid[i, 0] + if jnt_id in robot.joint_ids: + robot.actuator_ids.append(i) + if not robot.actuator_ids and len(world.robots) == 1: + robot.actuator_ids = list(range(new_model.nu)) + + logger.debug("eject_robot %r: scene rebuilt", robot_name) + return True + + +# ============================================================================= +# Agent-authored raw MJCF (Stage 6) +# ============================================================================= + + +def replace_scene_mjcf(world: SimWorld, xml: str) -> bool: + """Atomically swap the whole scene for agent-written MJCF. + + Validated by actually compiling it. On failure raises ``ValueError`` with + MuJoCo's compiler error verbatim. On success, the old spec/model/data are + replaced and all per-robot joint/actuator IDs re-discovered (but since + the agent may have changed the whole scene, the ``world.robots`` dict + is NOT touched - that's the caller's responsibility). + """ + mj = _ensure_mujoco() + new_spec = SpecBuilder.from_mjcf_string(xml) + # Compile eagerly so malformed XML fails here rather than on the next + # mj_step. 
+ new_model = new_spec.compile() + new_data = mj.MjData(new_model) + + world._backend_state["spec"] = new_spec + world._model = new_model + world._data = new_data + try: + world._backend_state["xml"] = new_spec.to_xml() + except Exception: + pass + return True diff --git a/strands_robots/simulation/mujoco/simulation.py b/strands_robots/simulation/mujoco/simulation.py new file mode 100644 index 0000000..3ffe946 --- /dev/null +++ b/strands_robots/simulation/mujoco/simulation.py @@ -0,0 +1,1987 @@ +"""MuJoCo Simulation backend - AgentTool orchestrator + shared state host. + +Architecture notes (honest version, see GH #118) +------------------------------------------------ +The ``Simulation`` class uses multiple-inheritance to compose four mixins +(``PhysicsMixin``, ``RenderingMixin``, ``RecordingMixin``, ``RandomizationMixin``) +on top of the ``SimEngine`` ABC and the Strands ``AgentTool`` base. The +split keeps each file navigable (physics.py ~1150 lines, rendering.py ~730, +etc.) but the mixin boundaries describe *where code lives*, NOT the +coupling graph. + +Every mixin reaches back into this class for the same shared state: + + self._world - SimWorld handle (model + data + bookkeeping) + self._lock - serializes mj_step and ctrl[] writes + self._mj - cached ``mujoco`` module reference + self._policy_threads - per-robot Future dict (GH #114) + self._renderer_tls - thread-local renderer cache (macOS CGL) + self._executor - ThreadPoolExecutor for async policies + +AND the cross-cutting helpers: + + self._require_world() - "is the world live?" guard + self._require_no_running_policy() - scene-mutation safety gate + self._prune_done_futures() - cleanup of stale Future refs + self._active_policy_robots() - introspection + prune + +Mixins declare these via ``if TYPE_CHECKING`` stubs so mypy accepts the +attribute lookups. This is NOT a Protocol - mixins are not enforceable; +the contract is *documentary*. 
The stubs exist so edits to the helpers +in this file propagate to the mixin type-checks without manual sync. + +The alternative (extract a ``_SimulationState`` dataclass + pass it to +mixins) was explored and rejected: threading the state through every +method would blow up the diff across every mutation call, and mypy +narrowing of ``state.world._model`` after a ``_require_world(state)`` +call does not work any better than narrowing through a bound method +(same limitation that led commit f5c8518 to back out the helper-based +dedup). + +So: the split is honest about being for file-size, not for decoupling. +""" + +import inspect +import json +import logging +import os +import re +import threading +import time +from collections.abc import AsyncGenerator +from concurrent.futures import Future, ThreadPoolExecutor +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from strands.tools.tools import AgentTool +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolSpec, ToolUse + +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.model_registry import ( + list_available_models, + resolve_model, +) +from strands_robots.simulation.model_registry import ( + register_urdf as _register_urdf, +) +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimStatus, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco +from strands_robots.simulation.mujoco.physics import PhysicsMixin +from strands_robots.simulation.mujoco.randomization import RandomizationMixin +from strands_robots.simulation.mujoco.recording import RecordingMixin +from strands_robots.simulation.mujoco.rendering import RenderingMixin +from strands_robots.simulation.mujoco.scene_ops import ( + eject_body_from_scene, + eject_robot_from_scene, + inject_camera_into_scene, + inject_object_into_scene, + inject_robot_into_scene, + replace_scene_mjcf, +) +from 
strands_robots.simulation.mujoco.spec_builder import SpecBuilder +from strands_robots.simulation.policy_runner import CooperativeStop + +if TYPE_CHECKING: + from strands_robots.policies import Policy + +logger = logging.getLogger(__name__) + +_TOOL_SPEC_PATH = Path(__file__).parent / "tool_spec.json" + +# Tool schema is 357 lines of JSON. `tool_spec` property is on the LLM hot path +# (called on every `strands` invocation). Load once at import, not per access. +with open(_TOOL_SPEC_PATH) as _f: + _TOOL_SPEC_SCHEMA: dict[str, Any] = json.load(_f) + + +class Simulation( + PhysicsMixin, + RenderingMixin, + RecordingMixin, + RandomizationMixin, + SimEngine, + AgentTool, +): + """Programmatic MuJoCo simulation environment as a Strands AgentTool. + + Gives AI agents the ability to create, modify, and control MuJoCo + simulation environments through natural language → tool actions. + + **Stateful session.** One MuJoCo world per instance; actions form an + implicit state machine starting with ``create_world``. Tools that mutate + the scene (``add_robot``, ``remove_robot``, ``add_object``, ``remove_object``, ``move_object``, ``add_camera``, ``remove_camera``, + ``load_scene``) are NOT safe to call while a policy is running via + ``start_policy`` - stop it first. Call ``destroy()`` or ``cleanup()`` at + session end to release the ThreadPoolExecutor, temp dirs, and MuJoCo + resources. + """ + + def __init__( + self, + tool_name: str = "sim", + default_timestep: float = 0.002, + default_width: int = 640, + default_height: int = 480, + mesh: bool = True, + peer_id: str | None = None, + **kwargs, + ): + super().__init__() + self.tool_name_str = tool_name + self.default_timestep = default_timestep + self.default_width = default_width + self.default_height = default_height + + self._world: SimWorld | None = None + self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix=f"{tool_name}_sim") + # Per-robot Future refs for *active* policies. 
Completed futures are + # pruned by ``_active_policy_futures()``/``_prune_done_futures()`` so + # the dict never grows unboundedly and never reports stale "running". + self._policy_threads: dict[str, Future] = {} + self._shutdown_event = threading.Event() + # ``self._lock`` serializes writes to MuJoCo ``model``/``data`` arrays + # and calls to ``mj_step`` - MuJoCo physics is NOT safe for concurrent + # mutation from multiple threads. This lock is the single point that + # makes concurrent per-robot policies safe: + # + # * Two policies on different robots can run in parallel at the + # *inference* level (observation build, action compute). + # * When either policy calls ``send_action``, it serializes here + # briefly to write its own ``ctrl[]`` slots and advance physics. + # * ``mj_step`` advances the whole scene - so two robots sharing + # one world share one physics clock. That's correct: one tick of + # physical time advances all bodies. + self._lock = threading.Lock() + + self._viewer_handle = None + self._viewer_thread = None + + # Thread-local renderer cache - MuJoCo Renderer uses thread-local GL + # contexts (CGL on macOS, GLX on Linux). Sharing renderers across + # threads causes SIGSEGV in cgl.free(). Each thread gets its own. + self._renderer_tls = threading.local() + self._renderer_model = None + + # Fail fast: verify MuJoCo is importable at construction time + # so consumers catch missing-dependency errors immediately. 
+ self._mj = _ensure_mujoco() + logger.info("🎮 Simulation tool '%s' initialized", tool_name) + + # Public Properties + + @property + def mj_model(self): + """Direct access to the MuJoCo model (mujoco.MjModel).""" + return self._world._model if self._world else None + + @property + def mj_data(self): + """Direct access to the MuJoCo data (mujoco.MjData).""" + return self._world._data if self._world else None + + # Robot-compatible interface + + def get_observation(self, robot_name: str | None = None, *, skip_images: bool = False) -> dict[str, Any]: + """Get full observation for a robot: joint state + all attached cameras. + + See :meth:`SimEngine.get_observation` for the schema contract. + """ + if self._world is None or self._world._model is None: + return {} + if robot_name is None: + if not self._world.robots: + return {} + robot_name = next(iter(self._world.robots)) + if robot_name not in self._world.robots: + return {} + if skip_images and self._world is not None and self._world._backend_state.get("recording"): + # T26: dataset recording needs every frame's image obs. Override + # the policy's skip hint when an active recorder is attached. + skip_images = False + return self._get_sim_observation(robot_name, skip_images=skip_images) + + def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_substeps: int = 1) -> None: + """Apply action to simulation (Robot ABC compatible). + + Thread-safety: acquires self._lock around ctrl writes + mj_step, + as documented in base.py's SimEngine contract. Concurrent calls + from the agent's dispatch thread and a PolicyRunner worker are + serialized here. 
+ """ + if self._world is None or self._world._model is None: + return + if robot_name is None: + if not self._world.robots: + return + robot_name = next(iter(self._world.robots)) + if robot_name not in self._world.robots: + return + with self._lock: + self._apply_sim_action(robot_name, action, n_substeps=n_substeps) + + # World Management + + def _cheap_robot_count(self) -> int: + try: + from strands_robots.registry import list_robots as _registry_list_robots + + return len(_registry_list_robots(mode="sim")) + except ImportError: + return 0 + + def create_world( + self, timestep: float | None = None, gravity: list[float] | None = None, ground_plane: bool = True + ) -> dict[str, Any]: + """Create a new simulation world.""" + # mujoco verified at __init__ + + if self._world is not None and self._world._model is not None: + return { + "status": "error", + "content": [{"text": "World already exists. Use action='destroy' first, or action='reset'."}], + } + + if gravity is None: + _gravity = [0.0, 0.0, -9.81] + elif isinstance(gravity, (int, float)): + _gravity = [0.0, 0.0, float(gravity)] + else: + _gravity = list(gravity) + + self._world = SimWorld( + timestep=timestep or self.default_timestep, + gravity=_gravity, + ground_plane=ground_plane, + ) + + self._world.cameras["default"] = SimCamera( + name="default", + position=[1.5, 1.5, 1.2], + target=[0.0, 0.0, 0.3], + width=self.default_width, + height=self.default_height, + ) + + self._compile_world() + + return { + "status": "success", + "content": [ + { + "text": ( + "🌍 Simulation world created\n" + f"⚙️ Timestep: {self._world.timestep}s ({1 / self._world.timestep:.0f}Hz physics)\n" + f"🌐 Gravity: {self._world.gravity}\n" + f"📷 Default camera ready\n" + f"🤖 Robot models: {self._cheap_robot_count()} available\n" + "💡 Add robots: action='add_robot' (urdf_path or data_config)\n" + "💡 Add objects: action='add_object'\n" + "💡 List URDFs: action='list_urdfs'" + ) + } + ], + } + + def load_scene(self, scene_path: str) -> 
dict[str, Any]: + """Load a complete scene from an MJCF XML (or URDF) file. + + Replaces the currently-live spec with one parsed from disk. The + loaded spec becomes the source of truth, so downstream + ``add_object`` / ``add_camera`` / ``add_robot`` calls mutate it via + ``spec.recompile(model, data)`` and preserve the on-disk scene. + + Notes: + + * ``_backend_state["scene_loaded"] = True`` stays as a marker for + introspection (and for downstream callers that still check it, + though the scene_ops path is now uniform across both entry + points). + * ``_backend_state["scene_base_dir"]`` is recorded in case any + consumer needs the original source directory (e.g. for mesh path + resolution in followup inject operations on files with relative + mesh paths). + """ + if err := self._require_no_running_policy("load_scene"): + return err + mj = self._mj + + if not os.path.exists(scene_path): + return {"status": "error", "content": [{"text": f"Scene file not found: {scene_path}"}]} + + try: + self._world = SimWorld() + # Load the scene as a live MjSpec - this gives us a mutable AST + # for downstream add_object/add_robot operations, matching the + # contract produced by _compile_world for fresh worlds. + spec = SpecBuilder.from_file(scene_path) + self._world._backend_state["spec"] = spec + self._world._model = spec.compile() + self._world._data = mj.MjData(self._world._model) + self._world.status = SimStatus.IDLE + + # Cache the canonical serialisation; legacy readers use this. 
+ try: + self._world._backend_state["xml"] = spec.to_xml() + except Exception as xml_err: + logger.debug("spec.to_xml() on loaded scene failed: %s", xml_err) + + self._world._backend_state["scene_loaded"] = True + self._world._backend_state["scene_base_dir"] = os.path.dirname(os.path.abspath(scene_path)) + + return { + "status": "success", + "content": [ + { + "text": ( + f"🌍 Scene loaded from {os.path.basename(scene_path)}\n" + f"🦴 Bodies: {self._world._model.nbody}, 🔩 Joints: {self._world._model.njnt}, ⚡ Actuators: {self._world._model.nu}\n" + "💡 Use action='get_state' to inspect, action='step' to simulate" + ) + } + ], + } + except Exception as e: + logger.error("Failed to load scene: %s", e) + return {"status": "error", "content": [{"text": f"Failed to load scene: {e}"}]} + + def replace_scene_mjcf(self, xml: str) -> dict[str, Any]: + """Atomically replace the entire scene with agent-authored MJCF. + + Validated by actually compiling it via ``mujoco.MjSpec.from_string`` + and ``spec.compile()``. On failure returns a standard error dict with + MuJoCo's compiler error verbatim; on success the old ``_world._model``, + ``_world._data`` and ``_world._backend_state['spec']`` are replaced. + + Note: ``self._world.robots`` / ``objects`` / ``cameras`` registries + are LEFT UNTOUCHED. The raw MJCF can express elements that those + dataclasses can't (````, ````, ````, etc.) - + the agent is responsible for reconciling the registry with the new + scene if it cares. + + Use this as an escape hatch when the ``add_object`` / ``add_robot`` + vocabulary is insufficient. For additive changes, prefer those + methods - they keep the registry in sync. + """ + if self._world is None: + return {"status": "error", "content": [{"text": "No world. 
Use action='create_world' first."}]} + if err := self._require_no_running_policy("replace_scene_mjcf"): + return err + + try: + replace_scene_mjcf(self._world, xml) + except (ValueError, RuntimeError) as e: + return {"status": "error", "content": [{"text": f"MJCF compile failed: {e}"}]} + + model = self._world._model + return { + "status": "success", + "content": [ + { + "text": ( + f"🔄 Scene replaced via raw MJCF\n" + f"🦴 Bodies: {model.nbody}, 🔩 Joints: {model.njnt}, ⚡ Actuators: {model.nu}, 📷 Cameras: {model.ncam}\n" + "⚠️ world.robots / world.objects / world.cameras registries were NOT updated - " + "they describe our previous Python-side view of the scene." + ) + } + ], + } + + def _compile_world(self) -> None: + """Build the MjSpec from ``self._world`` and compile it to MjModel. + + Stashes the live ``MjSpec`` in ``_backend_state["spec"]`` so every + subsequent scene mutation uses ``spec.recompile(model, data)`` in + place - that preserves existing joint state automatically, replacing + the legacy XML-round-trip helpers in ``scene_ops.py``. + + Also exports ``spec.to_xml()`` to ``_backend_state["xml"]`` for any + consumer that still reads the raw MJCF string (e.g. ``load_scene`` + compatibility paths). + """ + mj = self._mj + assert self._world is not None # only called after create_world + spec = SpecBuilder.build(self._world) + self._world._backend_state["spec"] = spec + self._world._model = spec.compile() + self._world._data = mj.MjData(self._world._model) + try: + self._world._backend_state["xml"] = spec.to_xml() + except Exception as xml_err: + # spec.to_xml() is best-effort - if it fails we still have a + # valid compiled model. The cached XML is a convenience for + # tooling, not a correctness invariant. + logger.debug("spec.to_xml() failed: %s", xml_err) + self._world.status = SimStatus.IDLE + + def _recompile_world(self) -> dict[str, Any]: + """Rebuild MjModel from scratch via :meth:`_compile_world`. 
+ + This is the "nuke and pave" path used when the world config changes + in a way that can't be expressed as a spec mutation (e.g. clearing + every body). For incremental changes (add/remove body, camera), + prefer ``_recompile_preserving_state`` in ``scene_ops.py`` which + goes through ``spec.recompile(model, data)`` and preserves joint + state. + """ + try: + self._compile_world() + return {"status": "success"} + except Exception as e: + return {"status": "error", "content": [{"text": f"Recompile failed: {e}"}]} + + # Robot Management + + @staticmethod + def _ensure_meshes(model_path: str, robot_name: str) -> dict[str, Any] | None: + """Check if mesh files referenced by a model XML exist; auto-download if missing. + + Returns ``None`` on success (meshes present or downloaded cleanly) and + a standard error dict on auto-download failure. Caller MUST propagate + the error dict back to the agent - previously the return value was + ignored and the error was silently swallowed, leaving the agent to + hit a cryptic 'mesh not found' from MuJoCo instead. + """ + model_dir = os.path.dirname(os.path.abspath(model_path)) + + files_to_check = [model_path] + try: + with open(model_path) as _f: + top_content = _f.read() + for inc in re.findall(r' dict[str, Any]: + """Add a robot to the simulation via XML round-trip composition. + + Instead of replacing the entire world model, this method merges the + robot's bodies, actuators, assets, and sensors into the existing scene + XML. This preserves previously-created world state (gravity, objects, + cameras, other robots). + """ + if self._world is None: + return {"status": "error", "content": [{"text": "No world. Use action='create_world' first."}]} + if err := self._require_no_running_policy("add_robot"): + return err + if name in self._world.robots: + return {"status": "error", "content": [{"text": f"Robot '{name}' already exists."}]} + + # Resolution precedence: + # 1. explicit `urdf_path` (anything on disk). + # 2. 
`data_config` looked up in the model registry. + # 3. DEPRECATED: `name` looked up in the registry (undocumented + # fallback kept for one release with a DeprecationWarning). + # Pass `data_config` for new code; the `name`-as-registry-key path + # will be removed. + resolved_path = urdf_path + if not resolved_path and data_config: + resolved_path = resolve_model(data_config) + if not resolved_path: + return { + "status": "error", + "content": [ + { + "text": f"No model found for '{data_config}'.\n💡 Use action='list_urdfs' to see available robots" + } + ], + } + elif not resolved_path and name: + # deprecated fallback - try registry by instance name. + import warnings as _warnings + + resolved_path = resolve_model(name) + if resolved_path: + _warnings.warn( + f"add_robot: resolving model via instance name '{name}' is deprecated; " + "pass data_config='' instead.", + DeprecationWarning, + stacklevel=2, + ) + + if not resolved_path: + return {"status": "error", "content": [{"text": "Either urdf_path or data_config is required."}]} + if not os.path.exists(resolved_path): + return {"status": "error", "content": [{"text": f"File not found: {resolved_path}"}]} + + mj = self._mj + + robot = SimRobot( + name=name, + urdf_path=resolved_path, + position=position or [0.0, 0.0, 0.0], + orientation=orientation or [1.0, 0.0, 0.0, 0.0], + data_config=data_config, + namespace=f"{name}/", + ) + + try: + # Propagate auto-download failure back to the agent instead of + # silently eating it (previously this dict was discarded and + # the next MuJoCo load threw a cryptic 'mesh not found'). + mesh_err = self._ensure_meshes(resolved_path, data_config or name) + if mesh_err is not None: + self._world.robots.pop(name, None) + return mesh_err + + # Register the robot BEFORE attach so scene_ops can re-discover + # its joint/actuator IDs inside the merged model. + self._world.robots[name] = robot + # Track robot base path for asset path resolution. 
+ if not self._world._backend_state.get("robot_base_xml"): + self._world._backend_state["robot_base_xml"] = resolved_path + + # Compose into the live spec via spec.attach(). The helper sets + # robot.joint_names from the source spec (pre-namespacing) and + # then scene_ops._recompile_preserving_state resolves the + # post-attach joint/actuator IDs on the compiled model. + ok = inject_robot_into_scene(self._world, robot, resolved_path) + if not ok: + del self._world.robots[name] + return { + "status": "error", + "content": [{"text": f"Failed to inject robot '{name}' into scene."}], + } + + # Discover cameras that the robot's source MJCF declared. The + # compiled model already has them namespaced under + # ``{robot.name}/``. We probe the post-compile model + # instead of the source, which avoids loading a second model + # just for introspection. + pfx = robot.namespace or "" + model = self._world._model + for i in range(model.ncam): + cam_name = mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) + if not cam_name: + continue + # Strip the robot namespace for our Python-side key - the + # registry is keyed on the short name and we re-attach the + # namespace when passing to the renderer. + short = cam_name[len(pfx) :] if cam_name.startswith(pfx) else cam_name + if short not in self._world.cameras: + self._world.cameras[short] = SimCamera( + name=cam_name, + camera_id=i, + width=self.default_width, + height=self.default_height, + origin_robot=name, + ) + + # leave the freshly-added robot in a clean, deterministic + # zero state (qpos=qvel=ctrl=0) rather than silently settling + # under gravity for 100 steps. Callers that want a pre-settled + # pose should call step()/reset() explicitly. This makes + # `add_robot` -> `get_robot_state` observations meaningful for + # learning pipelines that expect t=0 to be a canonical start. 
+ mj.mj_resetData(self._world._model, self._world._data) + self._world.sim_time = 0.0 + self._world.step_count = 0 + mj.mj_forward(self._world._model, self._world._data) + + source = f"data_config='{data_config}'" if data_config else os.path.basename(resolved_path) + return { + "status": "success", + "content": [ + { + "text": ( + f"🤖 Robot '{name}' added to simulation\n" + f"📁 Source: {source} → {os.path.basename(resolved_path)}\n" + f"📍 Position: {robot.position}\n" + f"🔩 Joints: {len(robot.joint_names)} ({', '.join(robot.joint_names[:8])}{'...' if len(robot.joint_names) > 8 else ''})\n" + f"⚡ Actuators: {len(robot.actuator_ids)}\n" + f"📷 Cameras: {list(self._world.cameras.keys())}\n" + f"💡 Run policy: action='run_policy', robot_name='{name}'" + ) + } + ], + } + except Exception as e: + # Clean up on failure + self._world.robots.pop(name, None) + logger.error("Failed to add robot '%s': %s", name, e) + return {"status": "error", "content": [{"text": f"Failed to load: {e}"}]} + + def remove_robot(self, name: str) -> dict[str, Any]: + """Remove a robot and every element it injected (bodies, actuators, + sensors, equality/tendon refs) from the MJCF scene, then recompile. + + Previously remove_robot only popped the Python-side dict entry, + leaving the robot's MJCF in place. That blocked re-adding a robot + with the same name (MuJoCo rejects duplicates on compile) and left + stale bodies in the physics loop. + + Concurrency (GH #114): this is a *global-scope* mutation - the XML + round-trip reallocates ``model``/``data`` and invalidates cached + actuator/joint IDs held by every running PolicyRunner. We stop the + target robot's own policy first (cooperatively), then require no + OTHER robot is running a policy. + """ + if self._world is None or name not in self._world.robots: + return {"status": "error", "content": [{"text": f"Robot '{name}' not found."}]} + + # Step 1: cooperatively stop THIS robot's policy if running. 
+ # Has to happen before the global check so remove_robot works even + # when the target robot has an active policy (the common case). + if name in self._policy_threads: + self._world.robots[name].policy_running = False + try: + self._policy_threads[name].result(timeout=5.0) + except Exception: + pass + del self._policy_threads[name] + + # Step 2: after stopping our own, there must be no OTHER policy + # running - an XML round-trip will invalidate cached IDs everywhere. + if err := self._require_no_running_policy("remove_robot"): + return err + + # Pop the robot from the registry BEFORE the rebuild - eject_robot_from_scene + # rebuilds the spec from the remaining world.robots dict, so the robot + # we want to drop must no longer be in it. + del self._world.robots[name] + + ejected = eject_robot_from_scene(self._world, name) + if not ejected: + # Unlikely - rebuild from world state with one fewer robot. + return { + "status": "error", + "content": [{"text": f"Failed to eject robot '{name}' from scene."}], + } + + return {"status": "success", "content": [{"text": f"🗑️ Robot '{name}' removed."}]} + + def list_robots(self) -> list[str]: + """Return ordered robot names (SimEngine ABC). + + For the user-facing agent-tool action (rich dict output) see + :meth:`list_robots_info`, which the dispatcher aliases to the + ``list_robots`` action string. + """ + if self._world is None or not self._world.robots: + return [] + return list(self._world.robots.keys()) + + def robot_joint_names(self, robot_name: str) -> list[str]: + """Ordered joint names for ``robot_name`` (SimEngine ABC).""" + if self._world is None or robot_name not in self._world.robots: + return [] + return list(self._world.robots[robot_name].joint_names) + + def list_robots_info(self) -> dict[str, Any]: + """Agent-tool action: pretty-printed robot listing. 
+ + Separate from :meth:`list_robots` (which returns ``list[str]`` for + the SimEngine ABC) because the dispatcher needs a dict-shaped + response for user display. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if not self._world.robots: + return {"status": "success", "content": [{"text": "No robots. Use action='add_robot'."}]} + + lines = ["🤖 Robots in simulation:\n"] + for name, robot in self._world.robots.items(): + status = "🟢 running" if robot.policy_running else "⚪ idle" + lines.append( + f" • {name} ({os.path.basename(robot.urdf_path)})\n" + f" Position: {robot.position}, Joints: {len(robot.joint_names)}, " + f"Config: {robot.data_config or 'direct'}, Status: {status}" + ) + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + def get_robot_state(self, robot_name: str) -> dict[str, Any]: + """canonical name parameter is ``robot_name``. The router + accepts ``name`` as an alias (bidirectional) so legacy LLM calls + keep working, but new tool specs should document only robot_name.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + + mj = self._mj + robot = self._world.robots[robot_name] + model, data = self._world._model, self._world._data + + # Namespace-aware joint lookup (see add_robot / _apply_sim_action). 
+ pfx = robot.namespace or "" + state = {} + for jnt_name in robot.joint_names: + jnt_id = -1 + if pfx: + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, pfx + jnt_name) + if jnt_id < 0: + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jnt_name) + if jnt_id >= 0: + state[jnt_name] = { + "position": float(data.qpos[model.jnt_qposadr[jnt_id]]), + "velocity": float(data.qvel[model.jnt_dofadr[jnt_id]]), + } + + text = f"🤖 '{robot_name}' state (t={self._world.sim_time:.3f}s):\n" + for jnt, vals in state.items(): + text += f"{jnt}: pos={vals['position']:.4f}, vel={vals['velocity']:.4f}\n" + + return {"status": "success", "content": [{"text": text}, {"json": {"state": state}}]} + + # Object Management + + def add_object( + self, + name: str, + shape: str = "box", + position: list[float] | None = None, + orientation: list[float] | None = None, + size: list[float] | None = None, + color: list[float] | None = None, + mass: float = 0.1, + is_static: bool | None = None, + mesh_path: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Add an object to the simulation.""" + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("add_object"): + return err + if name in self._world.objects: + return {"status": "error", "content": [{"text": f"Object '{name}' exists."}]} + + # planes are infinite and must be static. Explicit + # is_static=False for a plane is an error; None or True both + # resolve to True. Non-plane shapes default to dynamic. + if shape == "plane": + if is_static is False: + return { + "status": "error", + "content": [ + { + "text": "add_object: shape='plane' requires is_static=True (planes are infinite and cannot have dynamic mass)." 
+ } + ], + } + is_static = True + elif is_static is None: + is_static = False + + obj = SimObject( + name=name, + shape=shape, + position=position or [0.0, 0.0, 0.0], + orientation=orientation or [1.0, 0.0, 0.0, 0.0], + size=size or [0.05, 0.05, 0.05], + color=color or [0.5, 0.5, 0.5, 1.0], + mass=mass, + mesh_path=mesh_path, + is_static=is_static, + ) + self._world.objects[name] = obj + + # Every scene mutation goes through spec.recompile() - no branching + # on robots / scene_loaded, and no XML round-trip. MjSpec preserves + # existing joint state automatically on recompile. + try: + if not inject_object_into_scene(self._world, obj): + # Injection returned False (compile error). Clean up. + self._world.objects.pop(name, None) + return { + "status": "error", + "content": [{"text": f"Failed to inject '{name}': spec recompile refused."}], + } + except (ValueError, RuntimeError) as e: + self._world.objects.pop(name, None) + return { + "status": "error", + "content": [{"text": f"Failed to inject '{name}' into live scene: {e}"}], + } + + return { + "status": "success", + "content": [ + { + "text": f"📦 '{name}' added: {shape} at {obj.position}, size={obj.size}, {'static' if is_static else f'{mass}kg'}" + } + ], + } + + def remove_object(self, name: str) -> dict[str, Any]: + if self._world is None or name not in self._world.objects: + return {"status": "error", "content": [{"text": f"Object '{name}' not found."}]} + if err := self._require_no_running_policy("remove_object"): + return err + del self._world.objects[name] + # spec-based path: eject_body_from_scene looks up the body in the + # live MjSpec, deletes it, and recompiles preserving remaining state. 
+ eject_body_from_scene(self._world, name) + return {"status": "success", "content": [{"text": f"🗑️ '{name}' removed."}]} + + def move_object( + self, name: str, position: list[float] | None = None, orientation: list[float] | None = None + ) -> dict[str, Any]: + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if name not in self._world.objects: + return {"status": "error", "content": [{"text": f"Object '{name}' not found."}]} + # Guard: move_object writes qpos + calls mj_forward, racing a running policy. + if err := self._require_no_running_policy("move_object"): + return err + + mj = self._mj + model, data = self._world._model, self._world._data + + jnt_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, f"{name}_joint") + if jnt_id >= 0: + qpos_addr = model.jnt_qposadr[jnt_id] + if position: + data.qpos[qpos_addr : qpos_addr + 3] = position + self._world.objects[name].position = position + if orientation: + data.qpos[qpos_addr + 3 : qpos_addr + 7] = orientation + self._world.objects[name].orientation = orientation + mj.mj_forward(model, data) + + return {"status": "success", "content": [{"text": f"📍 '{name}' moved to {position or 'same'}"}]} + + def list_objects(self) -> dict[str, Any]: + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + if not self._world.objects: + return {"status": "success", "content": [{"text": "No objects."}]} + + lines = ["📦 Objects:\n"] + for name, obj in self._world.objects.items(): + lines.append(f" • {name}: {obj.shape} at {obj.position}, {'static' if obj.is_static else f'{obj.mass}kg'}") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + # Camera Management + + def add_camera( + self, + name: str, + position: list[float] | None = None, + target: list[float] | None = None, + fov: float = 60.0, + width: int = 640, + height: int = 480, + ) -> dict[str, Any]: + """Add a camera to the scene (MJCF ```` injection). + + Naming: ``add_object(name="X", ...)`` injects its geom as + ``"X_geom"`` in MJCF, so cameras share the name table only with + other cameras and body names - not with object geoms. Duplicate + camera names are rejected upfront. + + Orientation: ``target`` is baked into the camera's ``xyaxes`` + attribute so the rendered view looks at that point (not just + forward-facing). Degenerate cases (target == position) error. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("add_camera"): + return err + + # validate position / target shape before we bake them into XML. + pos = position or [1.0, 1.0, 1.0] + tgt = target or [0.0, 0.0, 0.0] + for _lbl, _vec in (("position", pos), ("target", tgt)): + try: + if len(_vec) != 3: + return { + "status": "error", + "content": [{"text": f"add_camera: '{_lbl}' must be 3 elements [x,y,z], got {len(_vec)}"}], + } + except TypeError: + return {"status": "error", "content": [{"text": f"add_camera: '{_lbl}' must be a list of 3 numbers"}]} + # Degenerate orientation: position == target means no well-defined look direction. 
+ if all(abs(pos[i] - tgt[i]) < 1e-9 for i in range(3)): + return { + "status": "error", + "content": [ + { + "text": f"add_camera: 'position' and 'target' are identical ({pos}); camera has no look direction." + } + ], + } + + # reject duplicate camera names. Previously a second + # add_camera(name=existing) silently overwrote the registry entry but + # left the XML's unchanged, so the old pose stuck around for + # rendering. Explicit error avoids the surprise. + if name in self._world.cameras: + return { + "status": "error", + "content": [{"text": f"add_camera: camera '{name}' already exists. Remove it first."}], + } + + cam = SimCamera( + name=name, + position=pos, + target=tgt, + fov=fov, + width=width, + height=height, + ) + self._world.cameras[name] = cam + + # Spec-based path: inject_camera_into_scene adds the camera to the + # live spec and recompiles preserving state. + try: + if not inject_camera_into_scene(self._world, cam): + self._world.cameras.pop(name, None) + return { + "status": "error", + "content": [{"text": f"Failed to inject camera '{name}': spec recompile refused."}], + } + except (ValueError, RuntimeError) as e: + self._world.cameras.pop(name, None) + return { + "status": "error", + "content": [{"text": f"Failed to inject camera '{name}' into live scene: {e}"}], + } + + return {"status": "success", "content": [{"text": f"📷 Camera '{name}' added at {cam.position}"}]} + + def remove_camera(self, name: str) -> dict[str, Any]: + """Remove a named camera from the live scene. + + Pops the Python-side registry entry and then deletes the camera + from the MjSpec via :func:`SpecBuilder.remove_camera` so future + renders/compiles no longer see it. 
+ """ + if self._world is None or name not in self._world.cameras: + return {"status": "error", "content": [{"text": f"Camera '{name}' not found."}]} + if err := self._require_no_running_policy("remove_camera"): + return err + cam = self._world.cameras.pop(name) + + spec = self._world._backend_state.get("spec") + if spec is not None: + # Use the namespaced MuJoCo name if we have it (camera came from + # a robot's URDF), else the short name. + mj_name = cam.name or name + SpecBuilder.remove_camera(spec, mj_name) + # Recompile so nbody/ncam in _model match the new spec. + try: + self._world._model, self._world._data = spec.recompile(self._world._model, self._world._data) + try: + self._world._backend_state["xml"] = spec.to_xml() + except Exception: + pass + except (ValueError, RuntimeError) as e: + logger.warning("remove_camera recompile failed: %s", e) + + return {"status": "success", "content": [{"text": f"🗑️ Camera '{name}' removed."}]} + + # Simulation Control + + def step(self, n_steps: int = 1) -> dict[str, Any]: + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + # reject negative, accept zero as no-op + if not isinstance(n_steps, int): + try: + n_steps = int(n_steps) + except (TypeError, ValueError): + return { + "status": "error", + "content": [{"text": f"step: n_steps must be an integer, got {type(n_steps).__name__}"}], + } + if n_steps < 0: + return {"status": "error", "content": [{"text": f"step: n_steps must be >= 0, got {n_steps}"}]} + if n_steps == 0: + return { + "status": "success", + "content": [ + {"text": f"⏩ +0 steps (no-op) | t={self._world.sim_time:.4f}s | total={self._world.step_count}"} + ], + } + mj = self._mj + with self._lock: + for _ in range(n_steps): + mj.mj_step(self._world._model, self._world._data) + self._world.sim_time = self._world._data.time + self._world.step_count += n_steps + return { + "status": "success", + "content": [ + {"text": f"⏩ +{n_steps} steps | t={self._world.sim_time:.4f}s | total={self._world.step_count}"} + ], + } + + def reset(self) -> dict[str, Any]: + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + # reset during a running policy races mj_step -> SEGFAULT risk + if err := self._require_no_running_policy("reset"): + return err + mj = self._mj + with self._lock: + mj.mj_resetData(self._world._model, self._world._data) + self._world.sim_time = 0.0 + self._world.step_count = 0 + # Flip policy_running flag inside the lock so a racing worker + # thread cannot slip in one more mj_step between reset and flag + # flip. + for r in self._world.robots.values(): + r.policy_running = False + r.policy_steps = 0 + return {"status": "success", "content": [{"text": "🔄 Reset to initial state."}]} + + def get_state(self) -> dict[str, Any]: + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + lines = [ + "🌍 Simulation State", + f"🕐 t={self._world.sim_time:.4f}s (step {self._world.step_count})", + f"⚙️ dt={self._world.timestep}s | 🌐 g={self._world.gravity}", + f"🤖 Robots: {len(self._world.robots)} | 📦 Objects: {len(self._world.objects)} | 📷 Cameras: {len(self._world.cameras)}", + ] + if self._world._model: + lines.append( + f"🦴 Bodies: {self._world._model.nbody} | 🔩 Joints: {self._world._model.njnt} | ⚡ Actuators: {self._world._model.nu}" + ) + if self._world._backend_state.get("recording", False): + lines.append(f"🔴 Recording: {len(self._world._backend_state['trajectory'])} steps") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + def destroy(self) -> dict[str, Any]: + if self._world is None: + return {"status": "success", "content": [{"text": "No world to destroy."}]} + for r in self._world.robots.values(): + r.policy_running = False + self._close_viewer() + self._close_main_thread_renderers() + self._world = None + return {"status": "success", "content": [{"text": "🗑️ World destroyed."}]} + + def _close_main_thread_renderers(self) -> None: + """Close any renderers this thread owns and drop the TLS cache. + + Only safe for the main thread because ``mujoco.Renderer`` binds a + CGL/GLX context to the thread that created it; closing from another + thread can SIGSEGV in ``cgl.free()``. Worker threads drop their + renderers via ``threading.Thread`` teardown. + """ + tls = getattr(self, "_renderer_tls", None) + if tls is None: + return + renderers = getattr(tls, "renderers", None) + if renderers: + for r in list(renderers.values()): + try: + r.close() + except Exception: + pass + renderers.clear() + # Forget the model marker so the next _get_renderer() rebuilds fresh. 
+ if hasattr(tls, "model"): + tls.model = None + + def set_gravity(self, gravity: list[float] | float | int) -> dict[str, Any]: + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + # set_gravity during a running policy races the worker thread + if err := self._require_no_running_policy("set_gravity"): + return err + # validate length/dtype before numpy broadcast + if isinstance(gravity, (int, float)): + gravity = [0.0, 0.0, float(gravity)] + try: + if len(gravity) != 3: + return { + "status": "error", + "content": [ + {"text": f"set_gravity: 'gravity' must be a 3-element list [x,y,z], got {len(gravity)}"} + ], + } + gravity = [float(g) for g in gravity] + except (TypeError, ValueError) as e: + return { + "status": "error", + "content": [{"text": f"set_gravity: 'gravity' must be a 3-element list of numbers ({e})"}], + } + with self._lock: + self._world._model.opt.gravity[:] = gravity + self._world.gravity = gravity + return {"status": "success", "content": [{"text": f"🌐 Gravity: {gravity}"}]} + + def set_timestep(self, timestep: float) -> dict[str, Any]: + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + if err := self._require_no_running_policy("set_timestep"): + return err + # reject non-positive; warn on huge values + try: + timestep = float(timestep) + except (TypeError, ValueError): + return { + "status": "error", + "content": [{"text": f"set_timestep: must be a positive number, got {timestep!r}"}], + } + if timestep <= 0: + return {"status": "error", "content": [{"text": f"set_timestep: must be > 0, got {timestep}"}]} + warn = "" + if timestep > 0.1: + warn = f" ⚠️ unusually large timestep (>{0.1}s); physics may be unstable" + with self._lock: + self._world._model.opt.timestep = timestep + self._world.timestep = timestep + return {"status": "success", "content": [{"text": f"⏱️ Timestep: {timestep}s ({1 / timestep:.0f}Hz){warn}"}]} + + # Viewer + + def open_viewer(self) -> dict[str, Any]: + if self._world is None or self._world._model is None: + return {"status": "error", "content": [{"text": "No simulation to view."}]} + from strands_robots.simulation.mujoco.backend import _mujoco_viewer + + if _mujoco_viewer is None: + return {"status": "error", "content": [{"text": "mujoco.viewer not available."}]} + if self._viewer_handle is not None: + return {"status": "success", "content": [{"text": "👁️ Viewer already open."}]} + try: + self._viewer_handle = _mujoco_viewer.launch_passive(self._world._model, self._world._data) + return {"status": "success", "content": [{"text": "👁️ Interactive viewer opened."}]} + except Exception as e: + return {"status": "error", "content": [{"text": f"Viewer failed: {e}"}]} + + def _close_viewer(self) -> None: + if self._viewer_handle is not None: + try: + self._viewer_handle.close() + except Exception: + pass + self._viewer_handle = None + + def close_viewer(self) -> dict[str, Any]: + self._close_viewer() + return {"status": "success", "content": [{"text": "👁️ Viewer closed."}]} + + # URDF Registry + + def list_urdfs(self) -> dict[str, Any]: + return {"status": "success", "content": 
[{"text": list_available_models()}]} + + def register_urdf(self, data_config: str, urdf_path: str) -> dict[str, Any]: + """validate urdf_path before handing it to the registry. + + The router already rejects missing required params, so the + no-args case produces a friendly 'requires parameter ...' message + without hitting this body. + """ + if not urdf_path: + return { + "status": "error", + "content": [{"text": "register_urdf: 'urdf_path' must be a non-empty string."}], + } + p = Path(urdf_path) + if not p.exists(): + return { + "status": "error", + "content": [{"text": f"register_urdf: file not found: {urdf_path}"}], + } + if not p.is_file(): + return { + "status": "error", + "content": [{"text": f"register_urdf: not a file: {urdf_path}"}], + } + try: + # Smoke-check readability - mj.MjModel.from_xml_path will surface a + # better error later, but permission issues are worth catching now. + with p.open("rb"): + pass + except OSError as e: + return { + "status": "error", + "content": [{"text": f"register_urdf: cannot read {urdf_path}: {e}"}], + } + + _register_urdf(data_config, urdf_path) + resolved = resolve_model(data_config) + return { + "status": "success", + "content": [{"text": f"📋 Registered '{data_config}' → {urdf_path}\nResolved: {resolved or 'NOT FOUND'}"}], + } + + # Introspection + + def get_features(self, robot_name: str | None = None) -> dict[str, Any]: + """Describe the simulation's joints / actuators / cameras / robots. + + If ``robot_name`` is given, the joint / actuator / camera listings + are restricted to that robot (its namespaced MuJoCo names). The + ``robots`` map is also filtered to just that entry. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + + mj = self._mj + model = self._world._model + + # All-model name pools + all_joint_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_JOINT, i) for i in range(model.njnt)] + all_joint_names = [n for n in all_joint_names if n] + all_actuator_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_ACTUATOR, i) for i in range(model.nu)] + all_actuator_names = [n for n in all_actuator_names if n] + all_camera_names = [mj.mj_id2name(model, mj.mjtObj.mjOBJ_CAMERA, i) for i in range(model.ncam)] + all_camera_names = [n for n in all_camera_names if n] + + if robot_name is not None: + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + robot = self._world.robots[robot_name] + ns = (getattr(robot, "namespace", "") or "").rstrip("/") + prefix = f"{ns}/" if ns else "" + + def _scoped(pool: list[str]) -> list[str]: + if not prefix: + # Single-robot scene with no namespace: return the robot's own + # joints/actuators from the robot model rather than the pool. 
+ return pool + return [n for n in pool if n.startswith(prefix)] + + joint_names = robot.joint_names or _scoped(all_joint_names) + actuator_names = _scoped(all_actuator_names) + camera_names = _scoped(all_camera_names) + + robots_info = { + robot_name: { + "joint_names": robot.joint_names, + "n_joints": len(robot.joint_names), + "n_actuators": len(robot.actuator_ids), + "data_config": robot.data_config, + "source": os.path.basename(robot.urdf_path), + } + } + else: + joint_names = all_joint_names + actuator_names = all_actuator_names + camera_names = all_camera_names + + robots_info = {} + for rname, robot in self._world.robots.items(): + robots_info[rname] = { + "joint_names": robot.joint_names, + "n_joints": len(robot.joint_names), + "n_actuators": len(robot.actuator_ids), + "data_config": robot.data_config, + "source": os.path.basename(robot.urdf_path), + } + + features = { + "n_bodies": model.nbody, + "n_joints": model.njnt, + "n_actuators": model.nu, + "n_cameras": model.ncam, + "timestep": model.opt.timestep, + "joint_names": joint_names, + "actuator_names": actuator_names, + "camera_names": camera_names, + "robots": robots_info, + } + + lines = [ + "🔍 Simulation Features", + f"🦴 Joints ({model.njnt}): {', '.join(joint_names[:12])}{'...' if len(joint_names) > 12 else ''}", + f"⚡ Actuators ({model.nu}): {', '.join(actuator_names[:12])}{'...' 
if len(actuator_names) > 12 else ''}", + f"📷 Cameras ({model.ncam}): {', '.join(camera_names) if camera_names else 'none (free camera only)'}", + f"⏱️ Timestep: {model.opt.timestep}s ({1 / model.opt.timestep:.0f}Hz)", + ] + for rname, rinfo in robots_info.items(): + lines.append( + f"🤖 {rname}: {rinfo['n_joints']} joints, {rinfo['n_actuators']} actuators ({rinfo['source']})" + ) + + return { + "status": "success", + "content": [{"text": "\n".join(lines)}, {"json": {"features": features}}], + } + + # AgentTool Interface + + @property + def tool_name(self) -> str: + return self.tool_name_str + + @property + def tool_type(self) -> str: + return "simulation" + + def _require_world(self) -> dict[str, Any] | None: + """Return unified 'no world' error or None if world is live. + + Replaces scattered ``"No simulation."`` / ``"No world."`` strings. Every + action that touches ``self._world`` / ``self._world._model`` / + ``self._world._data`` should call this first. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return { + "status": "error", + "content": [{"text": ("No world. Call create_world (or load_scene) first.")}], + } + return None + + def _prune_done_futures(self) -> None: + """Drop completed Future refs from self._policy_threads. + + Without this, list_policies_running and stale-active checks see + historical entries forever (see GH #120). + """ + done = [k for k, f in self._policy_threads.items() if f.done()] + for k in done: + self._policy_threads.pop(k, None) + + def _active_policy_robots(self) -> list[str]: + """Names of robots with a live (not-done) policy Future. + + Prunes stale entries as a side-effect so the returned list is + authoritative. Callers can introspect via ``list_policies_running``. 
+ """ + self._prune_done_futures() + return list(self._policy_threads.keys()) + + def _require_no_running_policy(self, action_name: str, robot_name: str | None = None) -> dict[str, Any] | None: + """Return an error dict if a disallowed policy is running, else None. + + Two scopes (GH #114): + + * ``robot_name=None`` (default) - **global scope**. Used by scene + mutations that touch the whole XML / model pointer (``add_robot``, + ``remove_robot``, ``add_object``, ``remove_object``, ``move_object``, + ``add_camera``, ``remove_camera``, ``load_scene``, ``set_gravity``, + ``set_timestep``). An XML round-trip swaps ``self._world._model`` + and ``self._world._data``; any live PolicyRunner worker holding + pointers to the old arrays will segfault when it next calls + ``mj_step``. Hard-fail. + + * ``robot_name="..."`` - **per-robot scope**. Used by actions that + are safe to run while *other* robots' policies are active + (start_policy on the same robot, stop_policy, etc.). Policies on + different robots can execute concurrently because MuJoCo physics + is serialized by ``self._lock`` and each robot writes to a + disjoint slice of ``data.ctrl[]``. + """ + self._prune_done_futures() + if robot_name is not None: + fut = self._policy_threads.get(robot_name) + if fut is not None and not fut.done(): + return { + "status": "error", + "content": [ + { + "text": ( + f"Cannot '{action_name}' on '{robot_name}' while its policy is running. " + f"Stop it first: action='stop_policy', name='{robot_name}'." + ) + } + ], + } + return None + + active = [name for name, f in self._policy_threads.items() if not f.done()] + if active: + names = ", ".join(f"'{n}'" for n in active) + return { + "status": "error", + "content": [ + { + "text": ( + f"Cannot '{action_name}' while a policy is running on {names}. " + "Stop it first: action='stop_policy'." 
+ ) + } + ], + } + return None + + @property + def tool_spec(self) -> ToolSpec: + # schema cached at module load; see _TOOL_SPEC_SCHEMA + return { + "name": self.tool_name_str, + "description": ( + "Programmatic MuJoCo simulation environment (stateful session). " + "One world per instance; actions form an implicit state machine starting with " + "create_world. Scene mutations (add_robot, remove_robot, add_object, remove_object, move_object, add_camera, remove_camera, " + "load_scene) are blocked while a policy is running - stop it first. " + "Create worlds, add robots from URDF " + "(direct path or auto-resolve from data_config name), add objects, run VLA policies, " + "render cameras, record trajectories, domain randomize. " + "Same Policy ABC as real robot control - sim ↔ real with zero code changes. " + "Actions: create_world, load_scene, reset, get_state, destroy, " + "add_robot, remove_robot, list_robots, get_robot_state, " + "add_object, remove_object, move_object, list_objects, " + "add_camera, remove_camera, " + "run_policy, start_policy, stop_policy, list_policies_running, " + "render, render_depth, render_all, get_contacts, " + "step, set_gravity, set_timestep, " + "randomize, " + "start_recording, stop_recording, get_recording_status, start_cameras_recording, stop_cameras_recording, get_cameras_recording_status, " + "open_viewer, close_viewer, " + "list_urdfs, register_urdf, get_features. " + "Call destroy() at session end to release resources." 
+ ), + "inputSchema": {"json": _TOOL_SPEC_SCHEMA}, + } + + async def stream( + self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any + ) -> AsyncGenerator[ToolResultEvent, None]: + try: + tool_use_id = tool_use.get("toolUseId", "") + input_data = tool_use.get("input", {}) + result = self._dispatch_action(input_data.get("action", ""), input_data) + yield ToolResultEvent(dict(toolUseId=tool_use_id, **result)) # type: ignore[typeddict-item] + except Exception as e: + yield ToolResultEvent( + { + "toolUseId": tool_use.get("toolUseId", ""), + "status": "error", + "content": [{"text": f"Sim error: {e}"}], + } + ) + + # Policy orchestration overrides (MuJoCo-specific wiring) + + def start_policy( + self, + robot_name: str, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + duration: float = 10.0, + control_frequency: float = 50.0, + action_horizon: int = 8, + fast_mode: bool = False, + video: dict[str, Any] | None = None, + policy_object: "Policy | None" = None, + n_steps: int | None = None, + max_steps: int | None = None, + ) -> dict[str, Any]: + """Start policy execution on a background thread (non-blocking). + + MuJoCo override: reuses the ThreadPoolExecutor owned by + ``Simulation`` so agent tools can kick off long-running policies + without blocking the event loop. + + Concurrency (GH #114): multiple policies can run simultaneously on + *different* robots. MuJoCo's ``mj_step`` and ``ctrl[]`` writes are + still serialized via ``self._lock`` (MuJoCo ``model``/``data`` are + not thread-safe for concurrent mutation), but each robot owns a + disjoint slice of ``data.ctrl[]`` so there's no semantic conflict. + + A second ``start_policy`` on the *same* robot is still rejected. + + accepts ``n_steps`` (primary) or legacy ``max_steps`` as an + alternate horizon specification; run_policy converts to duration. 
+ """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. Call create_world (or load_scene) first."}]} + if robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + + # Per-robot gate: another policy running on a DIFFERENT robot is fine. + if err := self._require_no_running_policy("start_policy", robot_name=robot_name): + return err + + future = self._executor.submit( + self.run_policy, + robot_name, + policy_provider=policy_provider, + policy_config=policy_config, + instruction=instruction, + duration=duration, + control_frequency=control_frequency, + action_horizon=action_horizon, + fast_mode=fast_mode, + video=video, + policy_object=policy_object, + n_steps=n_steps, + max_steps=max_steps, + ) + self._policy_threads[robot_name] = future + + return { + "status": "success", + "content": [{"text": f"🚀 Policy started on '{robot_name}' (async)"}], + } + + def _make_run_policy_hook(self, robot_name: str, instruction: str): + """MuJoCo override: recording + policy_running flag + lock. + + Returns an ``on_frame(step, obs, action)`` closure that: + * flips ``robot.policy_running`` so ``stop_policy`` can interrupt, + * appends to ``_backend_state["trajectory"]`` when recording, + * forwards frames to the LeRobot ``dataset_recorder`` if attached, + * raises ``PolicyStopped`` when the user calls ``stop_policy``. + """ + import numpy as np + + from strands_robots.simulation.models import TrajectoryStep + + world = self._world + if world is None or robot_name not in world.robots: + return None + + robot = world.robots[robot_name] + robot.policy_running = True + robot.policy_instruction = instruction + robot.policy_steps = 0 + + lock = self._lock + + def _hook(step: int, observation: dict[str, Any], action: dict[str, Any]) -> None: + # Cooperative cancellation: stop_policy flips this flag. 
+ if not robot.policy_running: + raise CooperativeStop(f"Policy stopped on '{robot_name}'") + + robot.policy_steps = step + 1 + + with lock: + if world._backend_state.get("recording", False): + world._backend_state["trajectory"].append( + TrajectoryStep( + timestamp=time.time(), + sim_time=world.sim_time, + robot_name=robot_name, + observation={k: v for k, v in observation.items() if not isinstance(v, np.ndarray)}, + action=action, + instruction=instruction, + ) + ) + rec = world._backend_state.get("dataset_recorder") + if rec is not None: + rec.add_frame(observation=observation, action=action, task=instruction) + + return _hook + + def run_policy( + self, + robot_name: str, + policy_provider: str = "mock", + policy_config: dict[str, Any] | None = None, + instruction: str = "", + duration: float = 10.0, + control_frequency: float = 50.0, + action_horizon: int = 8, + fast_mode: bool = False, + video: dict[str, Any] | None = None, + policy_object: "Policy | None" = None, + n_steps: int | None = None, + max_steps: int | None = None, + max_onframe_failures: int | None = None, + ) -> dict[str, Any]: + """MuJoCo ``run_policy`` override: pre-flight world check + graceful stop. + + Delegates to :meth:`SimEngine.run_policy` but clears the MuJoCo + ``policy_running`` flag in a ``finally`` clause and swallows + ``_PolicyStopped`` (which the ``on_frame`` hook raises on user + cancellation) into a normal "policy stopped" result. + + forwards ``n_steps`` / ``max_steps`` to the base so LLM callers + can specify horizon in steps rather than wall-clock seconds. + """ + if self._world is None or self._world._model is None or self._world._data is None: + return {"status": "error", "content": [{"text": "No world. 
Call create_world (or load_scene) first."}]} + + try: + return super().run_policy( + robot_name, + policy_provider=policy_provider, + policy_config=policy_config, + instruction=instruction, + duration=duration, + control_frequency=control_frequency, + action_horizon=action_horizon, + fast_mode=fast_mode, + video=video, + policy_object=policy_object, + n_steps=n_steps, + max_steps=max_steps, + max_onframe_failures=max_onframe_failures, + ) + finally: + if self._world is not None and robot_name in self._world.robots: + self._world.robots[robot_name].policy_running = False + + # Action name aliases (tool-action -> method-name) + _ACTION_ALIASES = { + "list_robots": "list_robots_info", + } + + # Input field name -> method parameter name (syntactic sugar for the LLM) + _FIELD_ALIASES = { + "checkpoint_name": "name", + "torque_vec": "torque", + } + + # Params the router passes through but not every method declares. + # These are used for cross-cutting concerns (e.g. video on run_policy) + # and must not be reported as "unknown" by the router. + _ROUTER_PASSTHROUGH = {"action"} + + # Vector params with expected length (for dimension validation before + # numpy/MuJoCo sees them). Length 3 = xyz unless noted. + _VECTOR_PARAM_LENGTHS: dict[str, int] = { + "position": 3, + "target": 3, + "origin": 3, + "force": 3, + "torque": 3, + "torque_vec": 3, + "gravity": 3, + "direction": 3, + "point": 3, + "orientation": 4, # quaternion (w,x,y,z) + "color": 4, # rgba + } + + def _validate_and_build_kwargs( + self, action: str, method_name: str, sig: inspect.Signature, remapped: dict[str, Any] + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: + """Validate input against method signature; return (kwargs, error_result). + + Exactly one of the tuple elements is non-None. + """ + # Strip self + VAR_POSITIONAL (*args) + VAR_KEYWORD (**kwargs) for signature + # introspection; **kwargs methods accept arbitrary inputs, so we skip the + # unknown-key check for them. 
+ named_params = { + n: p + for n, p in sig.parameters.items() + if n != "self" and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) + } + method_has_var_keyword = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + method_param_names = set(named_params) + accepted_field_names = method_param_names | set(self._FIELD_ALIASES.keys()) | self._ROUTER_PASSTHROUGH + + # run_policy folds flat video keys into a structured `video` dict; those + # flat keys are legitimate at the router boundary even though run_policy + # itself takes `video=`. + if action == "run_policy": + accepted_field_names |= {"output_path", "fps", "camera_name"} + + # name/robot_name are aliased in both directions in the legacy router; + # allow either here so we don't flag the alias as unknown. + if "name" in method_param_names: + accepted_field_names.add("robot_name") + if "robot_name" in method_param_names: + accepted_field_names.add("name") + + # 1) Unknown kwargs (skipped for **kwargs methods which legitimately passthrough) + unknown = [] if method_has_var_keyword else [k for k in remapped if k not in accepted_field_names] + if unknown: + valid_sorted = sorted(method_param_names - {"action"}) + return None, { + "status": "error", + "content": [ + {"text": (f"Unknown parameter '{unknown[0]}' for action '{action}'. 
Valid: {valid_sorted}")} + ], + } + + # 2) Vector dimension validation (applies before method runs) + for vparam, expected_len in self._VECTOR_PARAM_LENGTHS.items(): + if vparam not in remapped: + continue + val = remapped[vparam] + if val is None: + continue + if not hasattr(val, "__len__"): + return None, { + "status": "error", + "content": [{"text": f"Parameter '{vparam}' must be a list of {expected_len} numbers."}], + } + if len(val) != expected_len: + return None, { + "status": "error", + "content": [ + {"text": (f"Parameter '{vparam}' must be a list of {expected_len} numbers, got {len(val)}.")} + ], + } + for i, component in enumerate(val): + if not isinstance(component, (int, float)) or isinstance(component, bool): + return None, { + "status": "error", + "content": [ + {"text": (f"Parameter '{vparam}'[{i}] must be numeric, got {type(component).__name__}.")} + ], + } + + # 3) Build kwargs + check required params + kwargs: dict[str, Any] = {} + for param_name, param in named_params.items(): + if param_name == "name" and "name" not in remapped and "robot_name" in remapped: + kwargs["name"] = remapped["robot_name"] + elif param_name == "robot_name" and "robot_name" not in remapped and "name" in remapped: + kwargs["robot_name"] = remapped["name"] + elif param_name in remapped: + kwargs[param_name] = remapped[param_name] + elif param.default is inspect.Parameter.empty: + return None, { + "status": "error", + "content": [{"text": f"Action '{action}' requires parameter '{param_name}'."}], + } + + return kwargs, None + + def _dispatch_action(self, action: str, d: dict[str, Any]) -> dict[str, Any]: + """Route action to the matching method with full input validation. + + Validation layer: + * unknown top-level params are rejected with a friendly message, + * missing required params produce a "requires parameter X" error + (no raw Python ``TypeError``), + * vector params have length + numeric dtype checked before the + value reaches numpy / MuJoCo. 
+ + Policy-provider kwargs are nested under ``policy_config`` (never + top-level) so the dispatcher stays backend-agnostic. + """ + method_name = self._ACTION_ALIASES.get(action, action) + method = getattr(self, method_name, None) + + if method is None or action.startswith("_"): + return {"status": "error", "content": [{"text": f"Unknown action: {action}"}]} + + cache = getattr(self, "_sig_cache", None) + if cache is None: + self._sig_cache = cache = {} + if method_name not in cache: + cache[method_name] = inspect.signature(method) + sig = cache[method_name] + + # Field-alias rewriting (before validation so the validator sees + # canonical names). + remapped = {k: v for k, v in d.items() if k != "action"} + for field_key, param_key in self._FIELD_ALIASES.items(): + if field_key in remapped and param_key not in remapped: + remapped[param_key] = remapped.pop(field_key) + + # Fold flat video keys into `video` dict for run_policy/start_policy. + if action in ("run_policy", "start_policy") and "video" not in remapped: + _video_flat: dict[str, Any] = {} + if "output_path" in remapped: + _video_flat["path"] = remapped.pop("output_path") + if "fps" in remapped: + _video_flat["fps"] = remapped.pop("fps") + # camera_name is shared with render(); only treat as video camera + # when paired with an output path. + if _video_flat.get("path") and "camera_name" in remapped: + _video_flat["camera"] = remapped.pop("camera_name") + if _video_flat.get("path"): + remapped["video"] = _video_flat + + kwargs, err = self._validate_and_build_kwargs(action, method_name, sig, remapped) + if err is not None: + return err + assert kwargs is not None + return method(**kwargs) + + def stop_policy(self, robot_name: str = "") -> dict[str, Any]: + """Stop a running policy on the given robot (cooperative cancellation). + + Counterpart to :meth:`start_policy`. 
Flips the robot's + ``policy_running`` flag; the background loop in + :meth:`_run_policy_loop` sees it and raises :class:`PolicyStopped` + which is caught cleanly inside :meth:`start_policy`. + + idempotent - if the robot exists but no policy is running, we + still return success with 'Was not running' so callers can call + stop_policy unconditionally. The only error case is an unknown + robot_name. + + empty robot_name returns a clear error instead of a silent + match against the first robot. + """ + if not robot_name: + return { + "status": "error", + "content": [{"text": "stop_policy requires 'robot_name'."}], + } + if self._world is None or robot_name not in self._world.robots: + return {"status": "error", "content": [{"text": f"Robot '{robot_name}' not found."}]} + robot = self._world.robots[robot_name] + was_running = robot.policy_running + robot.policy_running = False + msg = f"Stopped on '{robot_name}'" if was_running else f"Was not running on '{robot_name}'" + return {"status": "success", "content": [{"text": msg}]} + + def list_policies_running(self) -> dict[str, Any]: + """Return the names of robots currently running a policy. + + Useful for inspecting concurrent-policy state when running two or + more VLA arms in the same scene (GH #114). Always returns a + success dict so the LLM can parse it uniformly. Prunes stale + completed Future entries as a side effect. + """ + active = self._active_policy_robots() + if not active: + return { + "status": "success", + "content": [{"text": "⚪ No policies running."}], + } + robot_lines = "\n".join(f" • 🟢 {n}" for n in active) + return { + "status": "success", + "content": [{"text": f"🟢 Active policies ({len(active)}):\n{robot_lines}"}], + } + + # Cleanup + + # Default cleanup shutdown timeout (seconds). 
A policy worker might be + # mid-step when cleanup is called; give it bounded time to see the + # cooperative-stop flag and exit cleanly before we null the world and + # its in-flight ``mj_step`` segfaults on a nulled ``_model``/``_data``. + # Override in tests via ``cleanup(policy_stop_timeout=...)`` if needed. + _DEFAULT_POLICY_STOP_TIMEOUT = 5.0 + + def cleanup(self, policy_stop_timeout: float | None = None) -> None: + """Release every resource owned by this Simulation instance. + + Concurrency (GH #116): nulling ``self._world`` while a policy worker + thread is still inside ``mj_step(world._model, world._data)`` is a + SIGSEGV waiting to happen. Previously cleanup called + ``executor.shutdown(wait=False)`` right after setting + ``self._world = None``, which meant the worker could still be + holding stale pointers to freed arrays. The + ``policy_running = False`` flag was flipped but never awaited. + + New order: + 1. Signal every live policy to stop (``policy_running = False``). + 2. Await each outstanding Future with a bounded timeout - the + ``on_frame`` hook sees the flag at the top of its next call + and raises ``CooperativeStop`` which short-circuits run_policy. + 3. Any Future still not-done after the timeout: we log a warning + and proceed - at that point the worker is wedged somewhere + outside MuJoCo and a stale-pointer segfault is the lesser evil + than hanging the host process on exit. + 4. Only AFTER workers have unwound do we null ``self._world`` + and tear down renderers / the viewer / the executor. + + Args: + policy_stop_timeout: Seconds to wait per active policy future. + ``None`` (default) uses + ``_DEFAULT_POLICY_STOP_TIMEOUT`` (5s). Set to a small value + in tests that want fast teardown. + """ + if hasattr(self, "mesh") and self.mesh: + self.mesh.stop() + + timeout = policy_stop_timeout if policy_stop_timeout is not None else self._DEFAULT_POLICY_STOP_TIMEOUT + + # Step 1 + 2: cooperative stop + bounded join BEFORE nulling world. 
+ # The ``policy_running`` flag is read by the MuJoCo-specific + # ``_make_run_policy_hook`` at the top of its next call; setting + # it here makes the worker raise CooperativeStop at its next step. + if self._world is not None: + for r in self._world.robots.values(): + r.policy_running = False + + # Prune completed futures so we only wait on genuinely-live ones. + self._prune_done_futures() + if self._policy_threads: + for robot_name, fut in list(self._policy_threads.items()): + try: + fut.result(timeout=timeout) + except Exception as e: + # result() raises either the worker's exception OR a + # TimeoutError. Log and continue - we want cleanup to + # finish even on pathological workers. + logger.warning( + "cleanup: policy on '%s' did not stop within %.1fs: %s", + robot_name, + timeout, + e, + ) + self._policy_threads.clear() + + # Step 3: now it's safe to null the world. Any worker still alive + # at this point has already escaped MuJoCo (we've confirmed via + # fut.result()), so a nulled _model / _data is no longer racy. + if self._world: + self._world = None + + self._close_viewer() + # close main-thread renderers before dropping the TLS object. + # Renderers created on worker threads release their GL contexts + # when those threads terminate; calling close() cross-thread + # SIGSEGVs in cgl.free(), so we stay on main. + self._close_main_thread_renderers() + if hasattr(self, "_renderer_tls"): + self._renderer_tls = threading.local() + # Step 4: shut the executor down now that all our policy futures + # are either completed or abandoned. wait=False is OK at this + # point because we've already drained policy workers above - any + # remaining thread is render / observation work that's safe to + # outlive us. 
+ self._executor.shutdown(wait=False) + self._shutdown_event.set() + + def __enter__(self) -> "Simulation": + return self + + def __exit__(self, *exc: object) -> None: + self.cleanup() + + def __del__(self) -> None: + try: + self.cleanup() + except Exception: + pass diff --git a/strands_robots/simulation/mujoco/spec_builder.py b/strands_robots/simulation/mujoco/spec_builder.py new file mode 100644 index 0000000..1a23406 --- /dev/null +++ b/strands_robots/simulation/mujoco/spec_builder.py @@ -0,0 +1,447 @@ +"""MjSpec-based MJCF builder - programmatic scene construction via the MuJoCo AST. + +This is the ONLY path for building / mutating MuJoCo scenes in strands-robots. +It replaces the string-concat ``MJCFBuilder`` (deleted) and the XML-round-trip +helpers in ``scene_ops.py``: + +- ``SpecBuilder.build(world)``: build a fresh ``MjSpec`` from a ``SimWorld``. +- ``add_object`` / ``remove_body`` / ``add_camera``: mutate an existing spec. +- ``attach_robot``: compose a URDF/MJCF file into a scene with a name prefix. +- ``replace_scene``: load an agent-authored MJCF string as the new scene. + +All builders return a ``MjSpec`` that the caller compiles via ``spec.compile()`` +or re-compiles in-place via ``spec.recompile(model, data)`` (which preserves +existing joint state automatically). + +This module does NOT import any XML / ElementTree / regex machinery - every +transformation goes through MuJoCo's own AST. +""" + +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np + +from strands_robots.simulation.models import SimCamera, SimObject, SimRobot, SimWorld +from strands_robots.simulation.mujoco.backend import _ensure_mujoco + +logger = logging.getLogger(__name__) + + +# MuJoCo geom-type enum mapping. Populated lazily on first call so module +# import doesn't require mujoco to be installed (backend _ensure_mujoco gates). 
+_GEOM_TYPE_CACHE: dict[str, int] | None = None + + +def _geom_type(shape: str) -> int: + """Map our shape-name vocabulary to MuJoCo's ``mjtGeom`` enum. + + Raises ValueError for shapes unsupported by the current pipeline. New + shapes (``ellipsoid``, ``hfield``) can be added here without touching + the rest of the builder. + """ + global _GEOM_TYPE_CACHE + if _GEOM_TYPE_CACHE is None: + mujoco = _ensure_mujoco() + _GEOM_TYPE_CACHE = { + "box": mujoco.mjtGeom.mjGEOM_BOX, + "sphere": mujoco.mjtGeom.mjGEOM_SPHERE, + "cylinder": mujoco.mjtGeom.mjGEOM_CYLINDER, + "capsule": mujoco.mjtGeom.mjGEOM_CAPSULE, + "ellipsoid": mujoco.mjtGeom.mjGEOM_ELLIPSOID, + "mesh": mujoco.mjtGeom.mjGEOM_MESH, + "plane": mujoco.mjtGeom.mjGEOM_PLANE, + } + try: + return _GEOM_TYPE_CACHE[shape] + except KeyError as e: + supported = ", ".join(sorted(_GEOM_TYPE_CACHE.keys())) + raise ValueError(f"Unsupported shape {shape!r}. Supported: {supported}.") from e + + +def _normalize_size(shape: str, size: list[float]) -> list[float]: + """Convert SimObject ``size`` convention to MuJoCo's per-geom size vector. + + MuJoCo's geom-size conventions (all in the LOCAL frame): + + * ``box``: half-extents ``[hx, hy, hz]`` + * ``sphere``: ``[radius]`` (MuJoCo uses size[0] as radius) + * ``cylinder``: ``[radius, half-height]`` + * ``capsule``: ``[radius, half-height]`` (cap hemisphere radius = radius) + * ``ellipsoid``: ``[rx, ry, rz]`` + * ``plane``: ``[hx, hy, grid_spacing]`` (hx/hy are half-sizes) + * ``mesh``: ``[]`` (mesh asset dictates extent; size ignored) + + ``SimObject.size`` is always 3 floats. Box/ellipsoid use all 3 as full + extents, sphere uses ``size[0]`` as diameter (MuJoCo halves it to radius), + cylinder/capsule use ``size[0]`` as diameter and ``size[2]`` as full height + (both halved), plane uses ``size[0]``/``size[1]`` as full extents (halved). 
+ """ + if shape == "box": + sx, sy, sz = size if len(size) >= 3 else (0.1, 0.1, 0.1) + return [sx / 2, sy / 2, sz / 2] + if shape == "sphere": + # Legacy builder used size[0]/2 as radius - preserve that. + radius = size[0] / 2 if size else 0.025 + return [radius, 0.0, 0.0] + if shape in ("cylinder", "capsule"): + radius = size[0] / 2 if size else 0.025 + half_h = size[2] / 2 if len(size) > 2 else 0.05 + return [radius, half_h, 0.0] + if shape == "ellipsoid": + sx, sy, sz = size if len(size) >= 3 else (0.05, 0.05, 0.05) + return [sx / 2, sy / 2, sz / 2] + if shape == "plane": + sx = size[0] if size else 1.0 + sy = size[1] if len(size) > 1 else sx + return [sx, sy, 0.01] + if shape == "mesh": + return [0.0, 0.0, 0.0] + raise ValueError(f"Cannot normalize size for shape {shape!r}.") + + +def _target_quat(position: list[float], target: list[float]) -> list[float] | None: + """Compute the camera orientation quaternion that makes ``position`` look + at ``target`` with world +Z as the up vector. + + Camera convention: + + * Forward (cam local -Z) = normalize(target - position) + * Right (cam local +X) = normalize(forward x up) + * Image-up (cam local +Y) = normalize(right x forward) + + Returns ``None`` for degenerate cases (target == position, or forward + parallel to up). Callers handle the degenerate case upstream. + + Uses MuJoCo's ``mju_mat2Quat`` so no hand-rolled quaternion math. + """ + mujoco = _ensure_mujoco() + + fwd = np.asarray(target, dtype=float) - np.asarray(position, dtype=float) + flen = float(np.linalg.norm(fwd)) + if flen < 1e-9: + return None + fwd /= flen + + up = np.array([0.0, 0.0, 1.0]) + right = np.cross(fwd, up) + rlen = float(np.linalg.norm(right)) + if rlen < 1e-9: + return None + right /= rlen + image_up = np.cross(right, fwd) + image_up /= float(np.linalg.norm(image_up)) + + # Columns of R are [right, image_up, -forward] - the camera's +X, +Y, +Z + # basis vectors expressed in world frame. Row-major layout for MuJoCo. 
+ rot = np.column_stack([right, image_up, -fwd]) + quat = np.zeros(4) + mujoco.mju_mat2Quat(quat, rot.ravel()) + return quat.tolist() + + +# ============================================================================= +# SpecBuilder - the public API +# ============================================================================= + + +class SpecBuilder: + """Builds and mutates ``mujoco.MjSpec`` trees from ``SimWorld`` state. + + Three distinct operations: + + * :meth:`build(world)` - fresh spec from all world contents. Called by + ``Simulation._compile_world`` when first creating a world. + * :meth:`add_object` / :meth:`remove_body` / :meth:`add_camera` - mutate + an existing spec in-place. Caller calls ``spec.recompile(model, data)`` + afterwards to propagate changes. State of unchanged joints is preserved + automatically by MuJoCo. + * :meth:`attach_robot` - compose a robot MJCF/URDF from disk into the + scene spec via ``spec.attach(other, prefix=..., frame=...)``. MuJoCo + handles name prefixing, asset deduplication, and default-class + namespacing natively. + """ + + # --------------------------------------------------------------- full build + @staticmethod + def build(world: SimWorld) -> Any: + """Build a fresh ``mujoco.MjSpec`` reflecting the current ``SimWorld``. + + Produces: + * option (timestep, gravity) + * visual + offscreen framebuffer size + * grid texture/material (for the ground plane) + * mesh assets for any objects with ``shape == "mesh"`` + * lights (``main_light``, ``fill_light``) + * ground plane (if ``world.ground_plane``) + * cameras + * objects + + Robots are NOT included here - they're attached separately via + :meth:`attach_robot` because each attach consumes a fresh MjSpec + loaded from the URDF/MJCF file on disk. + + Caller is responsible for ``spec.compile()`` to produce an MjModel. + """ + mujoco = _ensure_mujoco() + + spec = mujoco.MjSpec() + spec.modelname = "strands_sim" + + # Compiler + simulation options. 
+ spec.compiler.degree = False # radians + spec.compiler.autolimits = True + + spec.option.timestep = float(world.timestep) + spec.option.gravity = list(world.gravity) + + # Offscreen framebuffer - the default 640x480 is too small for common + # camera res. 1280x960 matches what the legacy builder used. + spec.visual.global_.offwidth = 1280 + spec.visual.global_.offheight = 960 + spec.visual.quality.shadowsize = 4096 + + # Ground texture + material - MuJoCo's built-in checkerboard. + grid_tex = spec.add_texture( + name="grid_tex", + type=mujoco.mjtTexture.mjTEXTURE_2D, + builtin=mujoco.mjtBuiltin.mjBUILTIN_CHECKER, + width=512, + height=512, + rgb1=[0.9, 0.9, 0.9], + rgb2=[0.7, 0.7, 0.7], + ) + grid_mat = spec.add_material(name="grid_mat", texrepeat=[8, 8], reflectance=0.1) + grid_mat.textures[mujoco.mjtTextureRole.mjTEXROLE_RGB] = grid_tex.name + + # Mesh assets for objects that declare ``shape == "mesh"``. + for obj in world.objects.values(): + if obj.shape == "mesh" and obj.mesh_path: + spec.add_mesh(name=f"mesh_{obj.name}", file=obj.mesh_path) + + # Lights. + spec.worldbody.add_light( + name="main_light", + pos=[0.0, 0.0, 3.0], + dir=[0.0, 0.0, -1.0], + diffuse=[1.0, 1.0, 1.0], + specular=[0.3, 0.3, 0.3], + ) + spec.worldbody.add_light( + name="fill_light", + pos=[1.0, 1.0, 2.0], + dir=[-0.5, -0.5, -1.0], + diffuse=[0.5, 0.5, 0.5], + ) + + # Ground plane. + if world.ground_plane: + spec.worldbody.add_geom( + name="ground", + type=mujoco.mjtGeom.mjGEOM_PLANE, + size=[5.0, 5.0, 0.01], + material="grid_mat", + conaffinity=1, + condim=3, + ) + + # Cameras. Skip cameras that were discovered inside a robot's URDF - + # they'll come back automatically via ``spec.attach(robot_spec)``. + # Re-adding them at the top level would collide with the attached + # namespaced copy at compile time. + for cam in world.cameras.values(): + if getattr(cam, "origin_robot", ""): + continue + SpecBuilder.add_camera(spec, cam) + + # Objects. 
+ for obj in world.objects.values(): + SpecBuilder.add_object(spec, obj) + + return spec + + # --------------------------------------------------------------- from_mjcf + @staticmethod + def from_mjcf_string(xml: str) -> Any: + """Load an MJCF XML string as a fresh spec. Used by ``replace_scene``. + + Raises ``ValueError`` on malformed XML via MuJoCo's compiler. + """ + mujoco = _ensure_mujoco() + return mujoco.MjSpec.from_string(xml) + + @staticmethod + def from_file(path: str) -> Any: + """Load an MJCF/URDF file as a fresh spec. + + MuJoCo 3.2+ reads URDF as well as MJCF via the same entry point - the + file extension + XML root determines the path. Raises ``ValueError`` + on invalid files. + """ + mujoco = _ensure_mujoco() + return mujoco.MjSpec.from_file(str(path)) + + # ------------------------------------------------------------ object add + @staticmethod + def add_object(spec: Any, obj: SimObject) -> None: + """Add a ``SimObject`` to ``spec.worldbody`` in-place. + + * Dynamic objects (``is_static=False``) get a freejoint + explicit + inertial block (diag 0.001, user-supplied mass) matching the + legacy builder. + * Static objects skip the freejoint and inertial. + * Meshes require a matching ``spec.add_mesh(...)`` to have been + registered (usually by :meth:`build`); this method does NOT + register mesh assets. 
+ """ + body = spec.worldbody.add_body( + name=obj.name, + pos=list(obj.position), + quat=list(obj.orientation), + ) + + if not obj.is_static: + body.add_freejoint(name=f"{obj.name}_joint") + body.mass = float(obj.mass) + body.inertia = [0.001, 0.001, 0.001] + body.ipos = [0.0, 0.0, 0.0] + body.explicitinertial = True + + geom_kwargs: dict[str, Any] = { + "name": f"{obj.name}_geom", + "type": _geom_type(obj.shape), + "rgba": list(obj.color), + "condim": 3, + } + if obj.shape == "mesh": + geom_kwargs["meshname"] = f"mesh_{obj.name}" + else: + geom_kwargs["size"] = _normalize_size(obj.shape, list(obj.size)) + + # Legacy code only set explicit friction on boxes; preserve parity. + if obj.shape == "box": + geom_kwargs["friction"] = [1.0, 0.5, 0.001] + + body.add_geom(**geom_kwargs) + + # ----------------------------------------------------------- camera add + @staticmethod + def add_camera(spec: Any, cam: SimCamera) -> None: + """Add a world-fixed camera. If ``cam.target`` is set, converts the + look-at direction to a quaternion via :func:`_target_quat`. + """ + mujoco = _ensure_mujoco() + pos = list(cam.position) + kwargs: dict[str, Any] = { + "name": cam.name, + "pos": pos, + "fovy": float(cam.fov), + "mode": mujoco.mjtCamLight.mjCAMLIGHT_FIXED, + } + target = getattr(cam, "target", None) + if target is not None: + quat = _target_quat(pos, list(target)) + if quat is not None: + kwargs["quat"] = quat + spec.worldbody.add_camera(**kwargs) + + # --------------------------------------------------------- body remove + @staticmethod + def remove_body(spec: Any, name: str) -> bool: + """Remove a body by name from the spec. + + Uses ``spec.delete(body)`` which walks the spec's typed registry. + Returns ``True`` if the body existed and was removed, ``False`` + otherwise (to match the legacy scene_ops API). + + Note: this removes ONLY the body; any actuators/sensors referencing + its joints must be cleaned up separately via :meth:`remove_refs_by_prefix`. 
+ That's only needed for robots - for plain object bodies there are + no actuators/sensors tied to them. + """ + try: + body = spec.body(name) + except (KeyError, ValueError): + return False + if body is None: + return False + spec.delete(body) + return True + + # ------------------------------------------------------- camera remove + @staticmethod + def remove_camera(spec: Any, name: str) -> bool: + """Remove a camera by name from the spec.""" + # spec.cameras returns the list; find by name + cameras = getattr(spec, "cameras", None) + if cameras is None: + return False + for cam in cameras: + if cam.name == name: + spec.delete(cam) + return True + return False + + # ------------------------------------------------------------- attach + @staticmethod + def attach_robot( + scene_spec: Any, + robot: SimRobot, + robot_file_path: str, + ) -> list[str]: + """Attach a URDF/MJCF file into the scene spec with a name prefix. + + Uses ``spec.attach(other, prefix=..., frame=...)`` which handles + body/joint/geom/actuator/sensor name prefixing automatically, dedups + shared assets (meshes, textures, materials), and namespaces default + classes - replacing ~400 lines of hand-rolled tree-walking from the + legacy ``scene_ops._prefix_robot_names`` + + ``_namespace_robot_default_classes``. + + Args: + scene_spec: the scene spec to mutate. + robot: ``SimRobot`` carrying ``name`` (used as prefix) and + ``position`` / ``orientation`` (used as attach frame). + robot_file_path: absolute or relative path to an MJCF/URDF file. + + Returns: + List of joint names belonging to the attached robot, in the order + MuJoCo discovered them (no prefix - caller namespaces via + ``robot.namespace`` when it resolves IDs post-compile). + """ + mujoco = _ensure_mujoco() + + robot_spec = mujoco.MjSpec.from_file(str(robot_file_path)) + + # Collect source joint names BEFORE attach - attach mutates the child + # spec in-place (the child gets reparented). 
+ source_joint_names: list[str] = [] + + def _walk(body: Any) -> None: + for j in body.joints: + jname = j.name or "" + if jname and jname not in source_joint_names: + source_joint_names.append(jname) + for sub in body.bodies: + _walk(sub) + + for top_body in robot_spec.worldbody.bodies: + _walk(top_body) + + frame = scene_spec.worldbody.add_frame( + pos=list(robot.position), + quat=list(robot.orientation), + ) + scene_spec.attach(robot_spec, prefix=f"{robot.name}/", frame=frame) + + return source_joint_names + + +__all__ = [ + "SpecBuilder", + "_geom_type", + "_normalize_size", + "_target_quat", +] diff --git a/strands_robots/simulation/mujoco/tool_spec.json b/strands_robots/simulation/mujoco/tool_spec.json new file mode 100644 index 0000000..05ad14f --- /dev/null +++ b/strands_robots/simulation/mujoco/tool_spec.json @@ -0,0 +1,364 @@ +{ + "type": "object", + "properties": { + "action": { + "type": "string", + "description": "Action to perform", + "enum": [ + "create_world", + "load_scene", + "replace_scene_mjcf", + "reset", + "get_state", + "destroy", + "add_robot", + "remove_robot", + "list_robots", + "get_robot_state", + "add_object", + "remove_object", + "move_object", + "list_objects", + "add_camera", + "remove_camera", + "run_policy", + "start_policy", + "stop_policy", + "list_policies_running", + "render", + "render_depth", + "get_contacts", + "step", + "set_gravity", + "set_timestep", + "randomize", + "start_recording", + "stop_recording", + "get_recording_status", + "open_viewer", + "close_viewer", + "list_urdfs", + "register_urdf", + "get_features", + "replay_episode", + "eval_policy", + "save_state", + "load_state", + "apply_force", + "raycast", + "multi_raycast", + "get_jacobian", + "get_energy", + "get_mass_matrix", + "inverse_dynamics", + "get_body_state", + "set_joint_positions", + "set_joint_velocities", + "get_sensor_data", + "set_body_properties", + "set_geom_properties", + "get_contact_forces", + "forward_kinematics", + "get_total_mass", + 
"export_xml", + "render_all", + "start_cameras_recording", + "stop_cameras_recording", + "get_cameras_recording_status" + ] + }, + "scene_path": { + "type": "string", + "description": "Path to MJCF/URDF scene file" + }, + "timestep": { + "type": "number" + }, + "gravity": { + "type": "array", + "items": { + "type": "number" + } + }, + "ground_plane": { + "type": "boolean" + }, + "urdf_path": { + "type": "string", + "description": "Path to URDF/MJCF file" + }, + "robot_name": { + "type": "string" + }, + "data_config": { + "type": "string", + "description": "Data config name (auto-resolves URDF)" + }, + "name": { + "type": "string", + "description": "Object/camera name" + }, + "shape": { + "type": "string", + "enum": [ + "box", + "sphere", + "cylinder", + "capsule", + "mesh", + "plane" + ] + }, + "position": { + "type": "array", + "items": { + "type": "number" + } + }, + "orientation": { + "type": "array", + "items": { + "type": "number" + } + }, + "size": { + "type": "array", + "items": { + "type": "number" + } + }, + "color": { + "type": "array", + "items": { + "type": "number" + } + }, + "mass": { + "type": "number" + }, + "is_static": { + "type": "boolean" + }, + "mesh_path": { + "type": "string" + }, + "target": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Camera target point" + }, + "fov": { + "type": "number", + "description": "Camera field of view" + }, + "width": { + "type": "integer" + }, + "height": { + "type": "integer" + }, + "policy_provider": { + "type": "string", + "description": "Policy provider name (e.g. 
groot, lerobot_async, lerobot_local, dreamgen, mock)" + }, + "instruction": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "action_horizon": { + "type": "integer" + }, + "control_frequency": { + "type": "number" + }, + "camera_name": { + "type": "string" + }, + "n_steps": { + "type": "integer" + }, + "output_path": { + "type": "string", + "description": "Trajectory/video export path" + }, + "fps": { + "type": "integer", + "description": "Video frames per second (for run_policy record_video)" + }, + "randomize_colors": { + "type": "boolean" + }, + "randomize_lighting": { + "type": "boolean" + }, + "randomize_physics": { + "type": "boolean" + }, + "randomize_positions": { + "type": "boolean" + }, + "position_noise": { + "type": "number" + }, + "seed": { + "type": "integer", + "description": "Random seed" + }, + "repo_id": { + "type": "string", + "description": "HuggingFace dataset repo ID" + }, + "push_to_hub": { + "type": "boolean", + "description": "Auto-push dataset to HuggingFace Hub on stop_recording" + }, + "vcodec": { + "type": "string", + "description": "Video codec for dataset recording (h264, hevc, libsvtav1)" + }, + "task": { + "type": "string", + "description": "Task description for dataset recording" + }, + "episode": { + "type": "integer", + "description": "Episode index for replay_episode" + }, + "root": { + "type": "string", + "description": "Local dataset root directory" + }, + "speed": { + "type": "number", + "description": "Replay speed multiplier (1.0 = original)" + }, + "n_episodes": { + "type": "integer", + "description": "Number of eval episodes" + }, + "max_steps": { + "type": "integer", + "description": "Max steps per eval episode" + }, + "success_fn": { + "type": "string", + "description": "Success function ('contact')" + }, + "fast_mode": { + "type": "boolean", + "description": "Skip sleep between actions for faster data collection" + }, + "body_name": { + "type": "string", + "description": "Target body name" + }, + 
"site_name": { + "type": "string", + "description": "Site name for Jacobian" + }, + "geom_name": { + "type": "string", + "description": "Geom name" + }, + "geom_id": { + "type": "integer", + "description": "Geom ID (alternative to geom_name)" + }, + "force": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Force vector [fx, fy, fz] in Newtons" + }, + "torque_vec": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Torque vector [tx, ty, tz] in N\u00b7m" + }, + "point": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Point of force application [x, y, z]" + }, + "origin": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Ray origin [x, y, z]" + }, + "direction": { + "type": "array", + "items": { + "type": "number" + }, + "description": "Ray direction [dx, dy, dz]" + }, + "directions": { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "number" + } + }, + "description": "Multiple ray directions for multi_raycast" + }, + "exclude_body": { + "type": "integer", + "description": "Body ID to exclude from raycast (-1=none)" + }, + "include_static": { + "type": "boolean", + "description": "Include static geoms in raycast" + }, + "positions": { + "type": "object", + "description": "Joint name \u2192 position mapping for set_joint_positions" + }, + "velocities": { + "type": "object", + "description": "Joint name \u2192 velocity mapping for set_joint_velocities" + }, + "sensor_name": { + "type": "string", + "description": "Specific sensor name (or omit for all)" + }, + "checkpoint_name": { + "type": "string", + "description": "Named checkpoint for save_state/load_state" + }, + "policy_config": { + "type": "object", + "description": "Provider-specific config dict forwarded to strands_robots.policies.create_policy. Contents depend on policy_provider. For 'groot': host, port, api_token, observation_mapping, action_mapping. 
For 'lerobot_local': pretrained_name_or_path, device, trust_remote_code, actions_per_step, use_processor, processor_overrides, observation_mapping, action_mapping. For 'mock': {} is fine.", + "additionalProperties": true + }, + "cameras": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of camera names. Omit to use every camera in the scene. Used by render_all / start_cameras_recording." + }, + "output_dir": { + "type": "string", + "description": "Directory for start_cameras_recording output. Defaults to /tmp/strands_robots/recordings." + }, + "xml": { + "type": "string", + "description": "Raw MJCF XML string. Used by replace_scene_mjcf." + } + }, + "required": [ + "action" + ] +} \ No newline at end of file diff --git a/strands_robots/simulation/policy_runner.py b/strands_robots/simulation/policy_runner.py new file mode 100644 index 0000000..b8a67ef --- /dev/null +++ b/strands_robots/simulation/policy_runner.py @@ -0,0 +1,619 @@ +"""Backend-agnostic policy execution against any ``SimEngine``. + +Runs the canonical obs → act → step loop using only the public ``SimEngine`` +interface. Zero knowledge of the underlying physics engine - MuJoCo, Isaac, +Newton and any future backend get ``run_policy`` / ``replay`` / ``evaluate`` +for free by implementing the ``SimEngine`` primitives. + +Three entry points: + +* :meth:`PolicyRunner.run` - blocking policy execution with optional video. +* :meth:`PolicyRunner.replay` - replay a recorded LeRobotDataset episode. +* :meth:`PolicyRunner.evaluate` - multi-episode evaluation with success metrics. 
+ +All three call only these public ``SimEngine`` methods: + +* ``get_observation(robot_name)`` +* ``send_action(action, robot_name, n_substeps)`` +* ``step(n_steps)`` +* ``reset()`` +* ``render(camera_name, width, height)`` + +And two public helpers for robot discovery: + +* ``list_robots()`` - ordered robot names in the world +* ``robot_joint_names(robot_name)`` - ordered joint names for a robot + +Thread safety: ``PolicyRunner`` itself is stateless per invocation. The +underlying ``SimEngine`` is responsible for thread-safety inside its own +methods (e.g. MuJoCo acquires a lock inside ``send_action`` / ``step``). +""" + +from __future__ import annotations + +import logging +import os +import time +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np + +from strands_robots._async_utils import _resolve_coroutine +from strands_robots.utils import require_optional + +if TYPE_CHECKING: + from strands_robots.policies.base import Policy + from strands_robots.simulation.base import SimEngine + +from strands_robots.simulation.models import TrajectoryStep + +logger = logging.getLogger(__name__) + + +# Hook signature: called every control step after send_action. +# on_frame(step_idx, observation, action) -> None +OnFrame = Callable[[int, dict[str, Any], dict[str, Any]], None] + +# Success function: called after each step during evaluate(). +# success_fn(observation) -> bool +SuccessFn = Callable[[dict[str, Any]], bool] + + +def _extract_frame_ndarray(render_result: dict) -> np.ndarray | None: + """Decode the PNG bytes emitted by ``SimEngine.render`` into an ndarray. + + ``render()`` returns the image nested inside a content block as + ``{"image": {"format": "png", "source": {"bytes": }}}``. This + helper walks that structure, decodes the PNG, and returns a (H, W, 3|4) + numpy array. Returns ``None`` if no image is found - the recorder then + skips the frame rather than aborting the rollout. 
+ """ + if not isinstance(render_result, dict): + return None + for block in render_result.get("content", []) or []: + if not isinstance(block, dict): + continue + image = block.get("image") + if not isinstance(image, dict): + continue + source = image.get("source") or {} + png_bytes = source.get("bytes") + if png_bytes is None and source.get("data") is not None: + import base64 + + png_bytes = base64.b64decode(source["data"]) + if not png_bytes: + continue + try: + import io + + from PIL import Image + + return np.asarray(Image.open(io.BytesIO(png_bytes)).convert("RGB")) + except Exception: + return None + return None + + +@dataclass(frozen=True) +class VideoConfig: + """Configuration for optional MP4 recording during :meth:`PolicyRunner.run`. + + Consolidates the five formerly-flat video parameters on + :meth:`SimEngine.run_policy` into one typed object. Recording is an + opt-in feature - if ``path`` is falsy, no recording occurs and the + other fields are ignored. + + Attributes: + path: Output MP4 path. ``None``/empty string → recording disabled. + fps: Frames per second to write. + camera: Camera name to render from. ``None`` → backend default. + width: Render width in pixels. + height: Render height in pixels. + """ + + path: str | None = None + fps: int = 30 + camera: str | None = None + width: int = 640 + height: int = 480 + + @property + def enabled(self) -> bool: + return bool(self.path) + + @classmethod + def from_dict(cls, d: dict[str, Any] | None) -> VideoConfig | None: + """Build from a plain dict (tool_spec dispatcher path). ``None`` passthrough.""" + if not d: + return None + # Accept both canonical keys and legacy/tool_spec aliases. 
class PolicyRunner:
    """Backend-agnostic policy execution against a ``SimEngine``.

    Construct with any ``SimEngine`` and call :meth:`run`, :meth:`replay`, or
    :meth:`evaluate`. The only attribute kept between calls is the ``sim``
    reference, so one runner can be reused.

    Args:
        sim: Any ``SimEngine`` implementation.
    """

    def __init__(self, sim: SimEngine):
        self.sim = sim

    # run(): blocking policy execution
    def run(
        self,
        robot_name: str,
        policy: Policy,
        *,
        instruction: str = "",
        duration: float = 10.0,
        control_frequency: float = 50.0,
        action_horizon: int = 8,
        fast_mode: bool = False,
        video: VideoConfig | None = None,
        on_frame: OnFrame | None = None,
        max_onframe_failures: int | None = None,
    ) -> dict[str, Any]:
        """Run ``policy`` on ``robot_name`` for ``duration`` seconds.

        Args:
            robot_name: Name of robot in the sim.
            policy: Already-constructed ``Policy`` instance. Callers (typically
                ``SimEngine.run_policy``) are responsible for policy
                construction so tests can inject mocks trivially.
            instruction: Natural-language instruction forwarded to the policy.
            duration: Seconds to run; converted to a control-step budget of
                ``int(duration * control_frequency)``.
            control_frequency: Target Hz of the control loop.
            action_horizon: Max actions consumed per policy call before
                requerying observation. Must be ``>= 1``.
            fast_mode: If True, skip real-time ``time.sleep`` between steps.
            video: Optional :class:`VideoConfig` - set ``video.path`` to enable
                MP4 recording via :meth:`SimEngine.render`.
            on_frame: Optional hook ``(step_idx, obs, action) -> None`` called
                after every ``send_action``. Raising ``CooperativeStop`` from
                the hook ends the run with a normal success result.
            max_onframe_failures: Maximum *consecutive* non-``CooperativeStop``
                exceptions from the ``on_frame`` hook before the runner aborts
                the episode. ``None`` (default) uses
                ``_MAX_CONSECUTIVE_ONFRAME_FAILURES``. A successful hook call
                resets the counter. See GH #117.

        Returns:
            ``{"status": "success"|"error", "content": [{"text": ...}]}``.
            Failures inside the control loop are logged and reported as an
            ``"error"`` result rather than raised.
        """
        writer = None
        frame_count = 0
        frame_interval = 0.0
        next_frame_step = 0.0
        video_path: str | None = None
        if video is not None and video.enabled:
            # video.enabled guarantees video.path is a non-empty str; narrow for mypy.
            assert video.path is not None
            video_path = video.path
            # Lazy optional import - only imageio is optional.
            imageio = require_optional(
                "imageio",
                pip_install="imageio imageio-ffmpeg",
                extra="sim-mujoco",
                purpose="video recording",
            )
            # Computed before the writer is opened so that fps == 0 cannot
            # raise with an already-open (and then leaked) writer.
            frame_interval = control_frequency / video.fps
            os.makedirs(os.path.dirname(os.path.abspath(video_path)), exist_ok=True)
            writer = imageio.get_writer(  # type: ignore[attr-defined]
                video_path, fps=video.fps, quality=8, macro_block_size=1
            )

        stopped_early = False
        # T26: skip camera rendering when the policy does not need images.
        _skip_images = not getattr(policy, "requires_images", True)
        step_count = 0
        start_time = time.time()
        try:
            if action_horizon < 1:
                # A zero horizon yields an empty chunk on every query, so the
                # step counter would never advance.
                raise ValueError(f"action_horizon must be >= 1, got {action_horizon}")

            total_steps = int(duration * control_frequency)
            action_sleep = 1.0 / control_frequency

            onframe_failure_limit = (
                max_onframe_failures if max_onframe_failures is not None else _MAX_CONSECUTIVE_ONFRAME_FAILURES
            )
            consecutive_onframe_failures = 0
            while step_count < total_steps:
                observation = self.sim.get_observation(robot_name=robot_name, skip_images=_skip_images)

                coro_or_result = policy.get_actions(observation, instruction)
                actions = _resolve_coroutine(coro_or_result)

                chunk = actions[:action_horizon]
                if len(chunk) == 0:
                    # Policy returned nothing: advance one physics step and
                    # count it, mirroring evaluate(), so a degenerate policy
                    # cannot hang the loop forever.
                    self.sim.step(n_steps=1)
                    step_count += 1
                    if not fast_mode:
                        time.sleep(action_sleep)
                    continue

                for action_dict in chunk:
                    if step_count >= total_steps:
                        break

                    self.sim.send_action(action_dict, robot_name=robot_name)

                    if on_frame is not None:
                        try:
                            on_frame(step_count, observation, action_dict)
                            consecutive_onframe_failures = 0
                        except CooperativeStop:
                            # Graceful stop requested by the hook; handled by
                            # the outer ``except CooperativeStop``.
                            raise
                        except Exception as e:
                            # A single hook failure is never fatal, but a hook
                            # that fails on every step would silently produce
                            # an empty dataset, so consecutive failures are
                            # counted and the episode fails loudly at the limit.
                            consecutive_onframe_failures += 1
                            logger.warning(
                                "on_frame hook failed (%d/%d consecutive): %s",
                                consecutive_onframe_failures,
                                onframe_failure_limit,
                                e,
                            )
                            if consecutive_onframe_failures >= onframe_failure_limit:
                                raise RuntimeError(
                                    f"on_frame hook failed {onframe_failure_limit} times in a row; "
                                    f"aborting episode to avoid silent dataset corruption. "
                                    f"Last error: {e!r}"
                                ) from e

                    step_count += 1

                    if writer is not None and step_count >= next_frame_step:
                        assert video is not None  # for mypy: writer only set when video.enabled
                        frame = self.sim.render(
                            camera_name=video.camera or "default",
                            width=video.width,
                            height=video.height,
                        )
                        # render() returns a status dict; decode the image block
                        # into an ndarray. Undecodable frames are skipped
                        # instead of aborting the rollout.
                        img_arr = _extract_frame_ndarray(frame)
                        if img_arr is not None:
                            writer.append_data(img_arr)
                            frame_count += 1
                        next_frame_step += frame_interval

                    if not fast_mode:
                        time.sleep(action_sleep)

        except CooperativeStop:
            stopped_early = True
            elapsed = time.time() - start_time
        except Exception as e:
            logger.exception("PolicyRunner.run failed")
            return {"status": "error", "content": [{"text": f"Policy failed: {e}"}]}
        else:
            elapsed = time.time() - start_time
        finally:
            # Runs on every exit path (including KeyboardInterrupt) so the
            # writer and its encoder process are never leaked.
            if writer is not None:
                writer.close()

        # Either finished all steps or was cooperatively stopped.
        sim_time = self._maybe_sim_time()
        prefix = "Policy stopped" if stopped_early else "Policy complete"
        text = (
            f"{prefix} on '{robot_name}'\n"
            f"🧠 {type(policy).__name__} | 🎯 {instruction}\n"
            f"⏱️ {elapsed:.1f}s | 📊 {step_count} steps"
        )
        if sim_time is not None:
            text += f" | 🕐 sim_t={sim_time:.3f}s"
        if video_path is not None:
            assert video is not None
            if frame_count > 0 and os.path.exists(video_path):
                file_kb = os.path.getsize(video_path) / 1024
                text += (
                    f"\n🎬 Video: {video_path}\n"
                    f"📹 {frame_count} frames, {video.fps}fps, "
                    f"{video.width}x{video.height} | 💾 {file_kb:.0f} KB"
                )
        return {"status": "success", "content": [{"text": text}]}

    # replay(): replay a LeRobotDataset episode

    def replay(
        self,
        repo_id: str,
        robot_name: str | None = None,
        *,
        episode: int = 0,
        root: str | None = None,
        speed: float = 1.0,
        action_key_map: list[str] | None = None,
    ) -> dict[str, Any]:
        """Replay a recorded LeRobotDataset episode through ``send_action``.

        Args:
            repo_id: HuggingFace dataset id (e.g. ``lerobot/pusht``).
            robot_name: Target robot. Defaults to first robot in the sim.
            episode: Episode index in the dataset.
            root: Optional local dataset root override.
            speed: Playback speed multiplier (1.0 = real time). Must be > 0.
            action_key_map: Optional list of joint names, one per action
                vector index. Required when dataset joint ordering differs
                from ``robot_joint_names(robot_name)``. If ``None``, positional
                mapping to ``robot_joint_names`` is used. Action values beyond
                the number of joint names are ignored.

        Returns:
            Standard status dict with per-frame stats; invalid arguments and
            dataset loading problems are reported as ``"error"`` results.
        """
        if not speed > 0:
            # Guards the 1 / (fps * speed) division below (also rejects NaN).
            return {"status": "error", "content": [{"text": f"speed must be > 0, got {speed!r}"}]}

        try:
            from strands_robots.dataset_recorder import load_lerobot_episode
        except ImportError:
            return {"status": "error", "content": [{"text": "lerobot not installed"}]}

        try:
            resolved_robot = robot_name or self._require_default_robot()
        except ValueError as e:
            return {"status": "error", "content": [{"text": f"{e}"}]}

        try:
            ds, episode_start, episode_length = load_lerobot_episode(repo_id, episode, root)
        except Exception as e:  # noqa: BLE001 - library errors are opaque
            return {"status": "error", "content": [{"text": f"{e}"}]}

        # Resolve joint name ordering for action vector index → action dict.
        joint_names = list(action_key_map) if action_key_map else self.sim.robot_joint_names(resolved_robot)

        # Fall back to 30 when fps is absent or falsy so the division is safe.
        dataset_fps = getattr(ds, "fps", None) or 30
        frame_interval = 1.0 / (dataset_fps * speed)
        frames_applied = 0
        start_time = time.time()

        for frame_idx in range(episode_length):
            step_start = time.time()
            frame = ds[episode_start + frame_idx]

            action_vals = frame.get("action") if isinstance(frame, dict) else None
            if action_vals is None:
                # No action at this index - just advance physics one step.
                self.sim.step(n_steps=1)
            else:
                # Normalise tensor / ndarray payloads to a plain list.
                if hasattr(action_vals, "numpy"):
                    action_vals = action_vals.numpy()
                if hasattr(action_vals, "tolist"):
                    action_vals = action_vals.tolist()

                # zip() stops at the shorter sequence: surplus action values
                # (or surplus joint names) are dropped.
                action_dict: dict[str, Any] = {name: float(val) for name, val in zip(joint_names, action_vals)}
                self.sim.send_action(action_dict, robot_name=resolved_robot)
            frames_applied += 1

            # Sleep only for the remainder of the frame budget.
            sleep_time = frame_interval - (time.time() - step_start)
            if sleep_time > 0:
                time.sleep(sleep_time)

        duration = time.time() - start_time
        return {
            "status": "success",
            "content": [
                {
                    "text": (
                        f"▶️ Replayed episode {episode} from {repo_id} on '{resolved_robot}'\n"
                        f"Frames: {frames_applied}/{episode_length} | "
                        f"Duration: {duration:.1f}s | Speed: {speed}x"
                    )
                },
                {
                    "json": {
                        "episode": episode,
                        "robot_name": resolved_robot,
                        "frames_applied": frames_applied,
                        "total_frames": episode_length,
                        "duration_s": round(duration, 2),
                        "speed": speed,
                    }
                },
            ],
        }

    # evaluate(): multi-episode success metrics

    def evaluate(
        self,
        robot_name: str,
        policy: Policy,
        *,
        instruction: str = "",
        n_episodes: int = 10,
        max_steps: int = 300,
        success_fn: SuccessFn | str | None = None,
    ) -> dict[str, Any]:
        """Evaluate ``policy`` for ``n_episodes`` episodes.

        Each episode resets the sim, then applies the first action of every
        policy query until ``max_steps`` is reached or the success check fires.

        Args:
            robot_name: Robot to evaluate.
            policy: Already-constructed ``Policy`` instance.
            instruction: Instruction forwarded to the policy.
            n_episodes: Number of reset → rollout episodes.
            max_steps: Cap per episode.
            success_fn: Either

                * ``None`` - never succeeds (dry run / performance probe).
                * ``"contact"`` - success when ``sim.get_contacts()`` reports
                  at least one contact; ``False`` when the backend lacks
                  ``get_contacts`` or the call fails.
                * callable ``(observation) -> bool``.

        Returns:
            Standard status dict with ``success_rate``, per-episode results.
            An unknown ``success_fn`` string yields an ``"error"`` result.
        """
        try:
            resolved_check = self._resolve_success_fn(success_fn)
        except ValueError as e:
            return {"status": "error", "content": [{"text": f"{e}"}]}

        # T26: skip camera rendering when the policy does not need images.
        _skip_images = not getattr(policy, "requires_images", True)
        results: list[dict[str, Any]] = []
        for ep in range(n_episodes):
            self.sim.reset()
            success = False
            steps = 0

            for _ in range(max_steps):
                observation = self.sim.get_observation(robot_name=robot_name, skip_images=_skip_images)
                coro_or_result = policy.get_actions(observation, instruction)
                actions = _resolve_coroutine(coro_or_result)

                if actions:
                    self.sim.send_action(actions[0], robot_name=robot_name)
                else:
                    # Policy returned nothing - still advance one physics step
                    # so episodes don't hang on degenerate policies.
                    self.sim.step(n_steps=1)

                steps += 1

                # NOTE(review): ``observation`` was sampled *before* the action
                # above was applied, so a callable success_fn sees the state
                # from the start of this step - confirm this is intended.
                if resolved_check is not None and resolved_check(observation):
                    success = True
                    break

            results.append({"episode": ep, "steps": steps, "success": success})

        n_success = sum(1 for r in results if r["success"])
        # max(..., 1) keeps the averages defined when n_episodes <= 0.
        success_rate = n_success / max(n_episodes, 1)
        avg_steps = sum(r["steps"] for r in results) / max(n_episodes, 1)

        return {
            "status": "success",
            "content": [
                {
                    "text": (
                        f"📊 Evaluation: {type(policy).__name__} on '{robot_name}'\n"
                        f"Episodes: {n_episodes} | Success: {n_success}/{n_episodes} "
                        f"({success_rate:.1%})\n"
                        f"Avg steps: {avg_steps:.0f}/{max_steps}"
                    )
                },
                {
                    "json": {
                        "success_rate": round(success_rate, 4),
                        "n_episodes": n_episodes,
                        "n_success": n_success,
                        "avg_steps": round(avg_steps, 1),
                        "max_steps": max_steps,
                        "episodes": results,
                    }
                },
            ],
        }

    # Helpers

    def _maybe_sim_time(self) -> float | None:
        """Best-effort read of sim time from any backend that exposes it.

        Tries two paths and never raises:

        1. ``sim._world.sim_time`` when the attribute exists and is numeric.
        2. ``sim.get_state()``: a top-level ``sim_time`` key, otherwise the
           first numeric ``sim_time`` inside a content block's ``json`` dict.

        Returns:
            The sim time as ``float``, or ``None`` when it cannot be read.
        """
        world = getattr(self.sim, "_world", None)
        if world is not None:
            t = getattr(world, "sim_time", None)
            if isinstance(t, (int, float)):
                return float(t)

        get_state = getattr(self.sim, "get_state", None)
        if get_state is None:
            return None
        try:
            state = get_state()
        except Exception:  # noqa: BLE001 - purely informational lookup
            return None
        if not isinstance(state, dict):
            return None

        if "sim_time" in state:
            try:
                return float(state["sim_time"])
            except (TypeError, ValueError):
                # Non-numeric top-level value: fall through to content blocks
                # instead of failing an otherwise successful run.
                pass
        for blk in state.get("content") or []:
            if isinstance(blk, dict) and isinstance(blk.get("json"), dict):
                t = blk["json"].get("sim_time")
                if isinstance(t, (int, float)):
                    return float(t)
        return None

    def _require_default_robot(self) -> str:
        """Return the first robot reported by ``sim.list_robots()``.

        Raises:
            ValueError: If the sim reports no robots.
        """
        robots = self.sim.list_robots()
        if not robots:
            raise ValueError("No robots in sim. Add one first.")
        return robots[0]

    def _resolve_success_fn(self, success_fn: SuccessFn | str | None) -> SuccessFn | None:
        """Normalise the ``success_fn`` argument of :meth:`evaluate`.

        Args:
            success_fn: ``None``, a callable, or the string ``"contact"``.

        Returns:
            ``None`` for ``None``, the callable unchanged, or a closure that
            queries ``sim.get_contacts()`` for ``"contact"``.

        Raises:
            ValueError: For any other string.
        """
        if success_fn is None:
            return None
        if callable(success_fn):
            return success_fn
        if success_fn == "contact":
            sim = self.sim

            def _contact_check(_obs: dict[str, Any]) -> bool:
                get_contacts = getattr(sim, "get_contacts", None)
                if get_contacts is None:
                    return False
                try:
                    result = get_contacts()
                except Exception:  # noqa: BLE001 - covers NotImplementedError too
                    return False
                # Accept either {"contacts": [...]} or {"n_contacts": int}
                if isinstance(result, dict):
                    n_contacts = result.get("n_contacts", 0)
                    # Only compare numeric counts; None would raise on "> 0".
                    if isinstance(n_contacts, (int, float)) and n_contacts > 0:
                        return True
                    contacts = result.get("contacts")
                    if isinstance(contacts, list) and contacts:
                        return True
                return False

            return _contact_check
        raise ValueError(f"Unknown success_fn string: {success_fn!r}")
All download logic lives in the @@ -48,30 +48,30 @@ def download_assets( if action == "list": return { "status": "success", - "content": [{"text": f"🤖 Available Robots:\n\n{format_robot_table()}"}], + "content": [{"text": f"Available Robots:\n\n{format_robot_table()}"}], } if action == "status": robots_info = list_available_robots() available = sum(1 for r in robots_info if r["available"]) - lines = [f"📊 {available} available, {len(robots_info) - available} missing"] + lines = [f"{available} available, {len(robots_info) - available} missing"] lines.extend( - f" {'✅' if r['available'] else '❌'} {r['name']:<20s} {r['category']:<12s} {r['description']}" + f"{'' if r['available'] else ''} {r['name']:<20s} {r['category']:<12s} {r['description']}" for r in robots_info ) - lines.append(f"\n📁 Cache: {get_user_assets_dir()}") + lines.append(f"\nCache: {get_user_assets_dir()}") return {"status": "success", "content": [{"text": "\n".join(lines)}]} if action == "download": robot_names = [r.strip() for r in robots.split(",") if r.strip()] if robots else None result = download_robots(names=robot_names, category=category, force=force) parts = [ - f"📦 Downloaded: {result['downloaded']}, Skipped: {result['skipped']}, Failed: {result['failed']}", + f"Downloaded: {result['downloaded']}, Skipped: {result['skipped']}, Failed: {result['failed']}", f"Method: {result.get('method', '?')}", ] if result.get("failed_details"): - parts.extend(f" ❌ {n}: {r}" for n, r in result["failed_details"].items()) - parts.append(f"📁 Assets: {result.get('assets_dir', '?')}") + parts.extend(f" {n}: {r}" for n, r in result["failed_details"].items()) + parts.append(f"Assets: {result.get('assets_dir', '?')}") return {"status": "success", "content": [{"text": "\n".join(parts)}]} return { @@ -81,4 +81,4 @@ def download_assets( except Exception as exc: logger.error("download_assets error: %s", exc) - return {"status": "error", "content": [{"text": f"❌ Error: {exc}"}]} + return {"status": "error", "content": 
[{"text": f"Error: {exc}"}]} diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index d70b499..aff43ce 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -76,7 +76,7 @@ def gr00t_inference( **Unitree G1 humanoid:** ``unitree_g1``, ``unitree_g1_full_body``, ``unitree_g1_locomanip``, - ``unitree_g1_real`` (N1.7 REAL_G1 embodiment — locomotion + bimanual manipulation) + ``unitree_g1_real`` (N1.7 REAL_G1 embodiment - locomotion + bimanual manipulation) **Franka Panda manipulators:** ``single_panda_gripper``, ``bimanual_panda_gripper``, ``bimanual_panda_hand`` @@ -98,7 +98,7 @@ def gr00t_inference( Set ``use_tensorrt=True`` to enable TensorRT inference. This compiles the model into an optimized engine on first run (may take several minutes). Subsequent runs load from ``trt_engine_path``. Dtype flags (``vit_dtype``, ``llm_dtype``, ``dit_dtype``) - control precision—lower precision (fp8/nvfp4) trades accuracy for speed. + control precision - lower precision (fp8/nvfp4) trades accuracy for speed. Authentication: The ``api_token`` parameter authenticates with the inference service. If omitted, @@ -118,9 +118,9 @@ def gr00t_inference( timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). trt_engine_path: Directory for TensorRT engine cache (default: ``gr00t_engine``). - vit_dtype: ViT precision with TensorRT—``fp16`` or ``fp8`` (default: ``fp8``). - llm_dtype: LLM precision with TensorRT—``fp16``, ``nvfp4``, or ``fp8`` (default: ``nvfp4``). - dit_dtype: DiT precision with TensorRT—``fp16`` or ``fp8`` (default: ``fp8``). + vit_dtype: ViT precision with TensorRT - ``fp16`` or ``fp8`` (default: ``fp8``). + llm_dtype: LLM precision with TensorRT - ``fp16``, ``nvfp4``, or ``fp8`` (default: ``nvfp4``). + dit_dtype: DiT precision with TensorRT - ``fp16`` or ``fp8`` (default: ``fp8``). 
http_server: Use HTTP REST API instead of ZMQ (default: False). api_token: API token for authentication. Falls back to ``GROOT_API_TOKEN`` env var. diff --git a/strands_robots/tools/lerobot_calibrate.py b/strands_robots/tools/lerobot_calibrate.py index dde2206..c38614d 100644 --- a/strands_robots/tools/lerobot_calibrate.py +++ b/strands_robots/tools/lerobot_calibrate.py @@ -406,7 +406,7 @@ def lerobot_calibrate( } # Format output - content_lines = ["🔧 **LeRobot Calibrations**", f"📍 Location: `{manager.base_path}`", ""] + content_lines = [" **LeRobot Calibrations**", f"Location: `{manager.base_path}`", ""] total_count = 0 for dev_type, models in structure.items(): @@ -416,13 +416,13 @@ def lerobot_calibrate( if not models: continue - content_lines.append(f"## 📁 **{dev_type.title()}**") + content_lines.append(f"## **{dev_type.title()}**") for model, calibrations in models.items(): if device_model and device_model != model: continue - content_lines.append(f"### 🤖 **{model}** ({len(calibrations)} calibrations)") + content_lines.append(f"### **{model}** ({len(calibrations)} calibrations)") for calib_id in calibrations: info = manager.get_calibration_info(dev_type, model, calib_id) @@ -448,7 +448,7 @@ def lerobot_calibrate( if not all([device_type, device_model, device_id]): return { "status": "error", - "content": [{"text": "❌ **view** action requires: device_type, device_model, and device_id"}], + "content": [{"text": "**view** action requires: device_type, device_model, and device_id"}], } assert device_type is not None and device_model is not None and device_id is not None @@ -456,25 +456,25 @@ def lerobot_calibrate( if not info: return { "status": "error", - "content": [{"text": f"❌ Calibration not found: `{device_type}/{device_model}/{device_id}`"}], + "content": [{"text": f"Calibration not found: `{device_type}/{device_model}/{device_id}`"}], } content_lines = [ - f"🔧 **Calibration Details: `{device_type}/{device_model}/{device_id}`**", - f"📍 **Path:** 
`{info['path']}`", - f"📅 **Modified:** {info['modified_time'].strftime('%Y-%m-%d %H:%M:%S')}", - f"📏 **Size:** {info['size_bytes']} bytes ({info['size_bytes'] / 1024:.1f} KB)", + f"**Calibration Details: `{device_type}/{device_model}/{device_id}`**", + f"**Path:** `{info['path']}`", + f"**Modified:** {info['modified_time'].strftime('%Y-%m-%d %H:%M:%S')}", + f"**Size:** {info['size_bytes']} bytes ({info['size_bytes'] / 1024:.1f} KB)", "", ] if info.get("data") and isinstance(info["data"], dict): - content_lines.extend([f"🤖 **Motor Configuration** ({info.get('motor_count', 0)} motors)", ""]) + content_lines.extend([f"**Motor Configuration** ({info.get('motor_count', 0)} motors)", ""]) for motor_name, motor_data in info["data"].items(): if isinstance(motor_data, dict): content_lines.extend( [ - f"### ⚙️ **{motor_name}**", + f"### ️ **{motor_name}**", f" - **ID:** {motor_data.get('id', 'N/A')}", f" - **Drive Mode:** {motor_data.get('drive_mode', 'N/A')}", f" - **Homing Offset:** {motor_data.get('homing_offset', 'N/A')}", @@ -493,12 +493,12 @@ def lerobot_calibrate( search_desc = f"query '{query}'" if query else "specified criteria" return { "status": "success", - "content": [{"text": f"🔍 **No calibrations found** matching {search_desc}"}], + "content": [{"text": f"**No calibrations found** matching {search_desc}"}], "results": [], "count": 0, } - content_lines = [f"🔍 **Search Results** ({len(results)} found)", f"📍 Query: `{query or 'all'}`", ""] + content_lines = [f"**Search Results** ({len(results)} found)", f"Query: `{query or 'all'}`", ""] for result in results: modified = result["modified_time"].strftime("%Y-%m-%d %H:%M:%S") @@ -507,7 +507,7 @@ def lerobot_calibrate( content_lines.extend( [ - f"### 🤖 **{result['device_type']}/{result['device_model']}/{result['device_id']}**", + f"### **{result['device_type']}/{result['device_model']}/{result['device_id']}**", f" - **Modified:** {modified}", f" - **Size:** {size_kb:.1f} KB", f" - **Motors:** {motor_info}", @@ 
-529,14 +529,14 @@ def lerobot_calibrate( if success: content_lines = [ - "💾 **Backup Completed Successfully**", - f"📁 **Location:** `{message}`", - f"📊 **Files copied:** {count}", + " **Backup Completed Successfully**", + f"**Location:** `{message}`", + f"**Files copied:** {count}", "", ] if device_type or device_model or device_id: - content_lines.append("🔍 **Filters applied:**") + content_lines.append(" **Filters applied:**") if device_type: content_lines.append(f" - Device Type: `{device_type}`") if device_model: @@ -551,37 +551,35 @@ def lerobot_calibrate( "files_count": count, } else: - return {"status": "error", "content": [{"text": f"❌ **Backup failed:** {message}"}]} + return {"status": "error", "content": [{"text": f"**Backup failed:** {message}"}]} elif action == "restore": if not backup_dir: - return {"status": "error", "content": [{"text": "❌ **restore** action requires: backup_dir"}]} + return {"status": "error", "content": [{"text": "**restore** action requires: backup_dir"}]} success, message, count = manager.restore_calibrations(Path(backup_dir), overwrite) if success: return { "status": "success", - "content": [ - {"text": f"✅ **{message}**\n📁 From: `{backup_dir}`\n🔄 Overwrite mode: `{overwrite}`"} - ], + "content": [{"text": f"**{message}**\nFrom: `{backup_dir}`\nOverwrite mode: `{overwrite}`"}], "restored_count": count, } else: - return {"status": "error", "content": [{"text": f"❌ **Restore failed:** {message}"}]} + return {"status": "error", "content": [{"text": f"**Restore failed:** {message}"}]} elif action == "delete": if not all([device_type, device_model, device_id]): return { "status": "error", - "content": [{"text": "❌ **delete** action requires: device_type, device_model, and device_id"}], + "content": [{"text": "**delete** action requires: device_type, device_model, and device_id"}], } assert device_type is not None and device_model is not None and device_id is not None if not manager.calibration_exists(device_type, device_model, 
device_id): return { "status": "error", - "content": [{"text": f"❌ Calibration not found: `{device_type}/{device_model}/{device_id}`"}], + "content": [{"text": f"Calibration not found: `{device_type}/{device_model}/{device_id}`"}], } success = manager.delete_calibration(device_type, device_model, device_id) @@ -589,19 +587,19 @@ def lerobot_calibrate( if success: return { "status": "success", - "content": [{"text": f"🗑️ **Successfully deleted:** `{device_type}/{device_model}/{device_id}`"}], + "content": [{"text": f"️ **Successfully deleted:** `{device_type}/{device_model}/{device_id}`"}], } else: return { "status": "error", - "content": [{"text": f"❌ **Failed to delete:** `{device_type}/{device_model}/{device_id}`"}], + "content": [{"text": f"**Failed to delete:** `{device_type}/{device_model}/{device_id}`"}], } elif action == "analyze": structure = manager.get_calibration_structure() if not any(structure.values()): - return {"status": "success", "content": [{"text": "📊 **No calibrations to analyze**"}], "analysis": {}} + return {"status": "success", "content": [{"text": "**No calibrations to analyze**"}], "analysis": {}} total_calibrations = 0 device_counts = {"teleoperators": 0, "robots": 0} @@ -631,10 +629,10 @@ def lerobot_calibrate( } content_lines = [ - "📊 **Calibration Analysis**", - f"📍 **Base Path:** `{manager.base_path}`", + " **Calibration Analysis**", + f"**Base Path:** `{manager.base_path}`", "", - "### 📈 **Summary Statistics**", + "### **Summary Statistics**", f" - **Total Calibrations:** {total_calibrations}", f" - **Teleoperators:** {device_counts['teleoperators']}", f" - **Robots:** {device_counts['robots']}", @@ -643,7 +641,7 @@ def lerobot_calibrate( ] if model_stats: - content_lines.extend(["### 🤖 **Device Model Breakdown**"]) + content_lines.extend(["### **Device Model Breakdown**"]) for model_key, count in sorted(model_stats.items()): motor_info = "" if model_key in motor_stats: @@ -671,8 +669,8 @@ def lerobot_calibrate( "status": "success", 
"content": [ { - "text": f"📍 **Calibration Path**\n`{calib_path}`\n\n" - f"{'✅ File exists' if exists else '❌ File does not exist'}" + "text": f"**Calibration Path**\n`{calib_path}`\n\n" + f"{' File exists' if exists else ' File does not exist'}" } ], "path": str(calib_path), @@ -684,7 +682,7 @@ def lerobot_calibrate( "status": "success", "content": [ { - "text": f"📍 **LeRobot Calibration Paths**\n\n" + "text": f"**LeRobot Calibration Paths**\n\n" f"**Base:** `{manager.base_path}`\n" f"**Teleoperators:** `{manager.teleop_path}`\n" f"**Robots:** `{manager.robot_path}`" @@ -700,7 +698,7 @@ def lerobot_calibrate( "status": "error", "content": [ { - "text": f"❌ **Unknown action:** `{action}`\n\n" + "text": f"**Unknown action:** `{action}`\n\n" "Available actions: list, view, search, backup, restore, delete, analyze, path" } ], @@ -708,4 +706,4 @@ def lerobot_calibrate( except Exception as e: logger.error(f"LeRobot calibrate tool error: {e}") - return {"status": "error", "content": [{"text": f"❌ **Tool execution failed:** {str(e)}"}]} + return {"status": "error", "content": [{"text": f"**Tool execution failed:** {str(e)}"}]} diff --git a/strands_robots/tools/lerobot_camera.py b/strands_robots/tools/lerobot_camera.py index 59a5a4e..fb42dd2 100644 --- a/strands_robots/tools/lerobot_camera.py +++ b/strands_robots/tools/lerobot_camera.py @@ -73,7 +73,7 @@ def _frame_to_image_content(frame: np.ndarray, format: str = "jpg") -> dict[str, except Exception as e: logger.error(f"Failed to convert frame to image content: {e}") - return {"text": f"❌ Failed to encode image: {str(e)}"} + return {"text": f"Failed to encode image: {str(e)}"} @tool @@ -140,7 +140,7 @@ def lerobot_camera( if camera_id is None: return { "status": "error", - "content": [{"text": "❌ camera_id required for capture action"}], + "content": [{"text": "camera_id required for capture action"}], } return _capture_single_image( camera_type, @@ -179,7 +179,7 @@ def lerobot_camera( if camera_id is None: return { 
"status": "error", - "content": [{"text": "❌ camera_id required for record action"}], + "content": [{"text": "camera_id required for record action"}], } return _record_video_sequence( camera_type, @@ -199,7 +199,7 @@ def lerobot_camera( if camera_id is None: return { "status": "error", - "content": [{"text": "❌ camera_id required for preview action"}], + "content": [{"text": "camera_id required for preview action"}], } return _preview_camera_live( camera_type, @@ -218,7 +218,7 @@ def lerobot_camera( if camera_id is None: return { "status": "error", - "content": [{"text": "❌ camera_id required for test action"}], + "content": [{"text": "camera_id required for test action"}], } return _test_camera_performance( camera_type, @@ -236,7 +236,7 @@ def lerobot_camera( if camera_id is None: return { "status": "error", - "content": [{"text": "❌ camera_id required for configure action"}], + "content": [{"text": "camera_id required for configure action"}], } return _configure_camera_settings( camera_type, @@ -253,13 +253,13 @@ def lerobot_camera( else: return { "status": "error", - "content": [{"text": f"❌ Unknown action: {action}"}], + "content": [{"text": f"Unknown action: {action}"}], } except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Camera operation failed: {str(e)}"}], + "content": [{"text": f"Camera operation failed: {str(e)}"}], } @@ -281,10 +281,10 @@ def _discover_cameras() -> dict[str, Any]: # Format discovery results discovery_info = [] - discovery_info.append("🔍 **Camera Discovery Results**\n") + discovery_info.append(" **Camera Discovery Results**\n") if opencv_cameras: - discovery_info.append("📹 **OpenCV Cameras:**") + discovery_info.append(" **OpenCV Cameras:**") for i, cam in enumerate(opencv_cameras): profile = cam.get("default_stream_profile", {}) discovery_info.append( @@ -298,7 +298,7 @@ def _discover_cameras() -> dict[str, Any]: discovery_info.append("") if realsense_cameras: - discovery_info.append("🎯 **RealSense Cameras:**") 
+ discovery_info.append(" **RealSense Cameras:**") for i, cam in enumerate(realsense_cameras): discovery_info.append( f" • **{cam.get('name', 'Unknown')}**\n" @@ -308,9 +308,9 @@ def _discover_cameras() -> dict[str, Any]: discovery_info.append("") if total_cameras == 0: - discovery_info.append("❌ **No cameras detected**") + discovery_info.append(" **No cameras detected**") else: - discovery_info.append(f"✅ **Total: {total_cameras} cameras found**") + discovery_info.append(f"**Total: {total_cameras} cameras found**") discovery_info.append(f" - OpenCV: {len(opencv_cameras)}") discovery_info.append(f" - RealSense: {len(realsense_cameras)}") @@ -319,7 +319,7 @@ def _discover_cameras() -> dict[str, Any]: except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Camera discovery failed: {str(e)}"}], + "content": [{"text": f"Camera discovery failed: {str(e)}"}], } @@ -327,15 +327,15 @@ def _list_camera_details(camera_type: str, camera_id: int | str | None = None) - """List detailed camera information and configurations.""" try: details = [] - details.append("📋 **Camera Configuration Details**\n") + details.append(" **Camera Configuration Details**\n") if camera_type.lower() == "opencv": - details.append("🎥 **OpenCV Camera System:**") + details.append(" **OpenCV Camera System:**") details.append(f" - Backend: {_get_opencv_backend_name()}") details.append(f" - Version: {cv2.__version__}") details.append(" - Available color modes: RGB, BGR") details.append(" - Supported rotations: 0°, 90°, 180°, 270°") - details.append(" - Async reading: ✅ Supported") + details.append(" - Async reading: Supported") details.append("") if camera_id is not None: @@ -344,8 +344,8 @@ def _list_camera_details(camera_type: str, camera_id: int | str | None = None) - camera = OpenCVCamera(config) camera.connect(warmup=False) - details.append(f"📸 **Camera {camera_id} Details:**") - details.append(" - Connection: ✅ Success") + details.append(f"**Camera {camera_id} Details:**") + 
details.append(" - Connection: Success") details.append(f" - Actual FPS: {camera.fps}") details.append(f" - Resolution: {camera.width}x{camera.height}") details.append(f" - Color Mode: {camera.color_mode.value}") @@ -353,30 +353,30 @@ def _list_camera_details(camera_type: str, camera_id: int | str | None = None) - camera.disconnect() except Exception as e: - details.append(f"📸 **Camera {camera_id} Details:**") - details.append(f" - Connection: ❌ Failed ({str(e)})") + details.append(f"**Camera {camera_id} Details:**") + details.append(f" - Connection: Failed ({str(e)})") elif camera_type.lower() == "realsense" and REALSENSE_AVAILABLE: - details.append("🎯 **RealSense Camera System:**") - details.append(" - SDK Available: ✅ Yes") - details.append(" - Depth Support: ✅ Yes") + details.append(" **RealSense Camera System:**") + details.append(" - SDK Available: Yes") + details.append(" - Depth Support: Yes") details.append(" - Multiple streams: Color, Depth, Infrared") details.append(" - Advanced features: Post-processing, alignment") else: if not REALSENSE_AVAILABLE and camera_type.lower() == "realsense": - details.append("🎯 **RealSense Camera System:**") - details.append(" - SDK Available: ❌ Not installed") + details.append(" **RealSense Camera System:**") + details.append(" - SDK Available: Not installed") details.append(" - Install with: `pip install pyrealsense2`") else: - details.append(f"❌ **Unknown camera type: {camera_type}**") + details.append(f"**Unknown camera type: {camera_type}**") return {"status": "success", "content": [{"text": "\n".join(details)}]} except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Camera details failed: {str(e)}"}], + "content": [{"text": f"Camera details failed: {str(e)}"}], } @@ -434,7 +434,7 @@ def _capture_single_image( if not success: return { "status": "error", - "content": [{"text": f"❌ Failed to save image: {file_path}"}], + "content": [{"text": f"Failed to save image: {file_path}"}], } # Get image 
info @@ -442,15 +442,15 @@ def _capture_single_image( file_size = os.path.getsize(file_path) result_info = [ - "📸 **Image Capture Success!**", - f"🎥 Camera: {camera_type.upper()} @ {camera_id}", - f"💾 Saved: `{file_path}`", - f"📐 Resolution: {img_width}x{img_height}", - f"💿 File size: {file_size:,} bytes", - f"⚡ Connect time: {connect_time:.3f}s", - f"📷 Capture time: {capture_time:.3f}s", - f"🔄 Async mode: {'✅' if async_mode else '❌'}", - f"🕐 Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + " **Image Capture Success!**", + f"Camera: {camera_type.upper()} @ {camera_id}", + f"Saved: `{file_path}`", + f"Resolution: {img_width}x{img_height}", + f"File size: {file_size:,} bytes", + f"Connect time: {connect_time:.3f}s", + f"Capture time: {capture_time:.3f}s", + f"Async mode: {'' if async_mode else ''}", + f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ] # Create image content for Converse API @@ -464,7 +464,7 @@ def _capture_single_image( except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Image capture failed: {str(e)}"}], + "content": [{"text": f"Image capture failed: {str(e)}"}], } @@ -555,13 +555,13 @@ def capture_single_camera(cam_id): total_time = time.time() - total_time # Format results and prepare content list - result_info = ["📸 **Batch Camera Capture Results:**", ""] + result_info = [" **Batch Camera Capture Results:**", ""] content_list = [] for result in results: if result["status"] == "success": result_info.append( - f"✅ **{result['camera_id']}**: {result['resolution']} " + f"**{result['camera_id']}**: {result['resolution']} " f"({result['file_size']:,} bytes, {result['capture_time']:.3f}s)" ) # Add image content if frame is available @@ -569,16 +569,16 @@ def capture_single_camera(cam_id): image_content = _frame_to_image_content(result["frame"], format) content_list.append(image_content) else: - result_info.append(f"❌ **{result['camera_id']}**: {result['message']}") + 
result_info.append(f"**{result['camera_id']}**: {result['message']}") result_info.extend( [ "", - "📊 **Summary:**", + " **Summary:**", f" - Success: {successful_captures}/{len(camera_ids)} cameras", f" - Total time: {total_time:.3f}s", f" - Save path: `{save_path}`", - f" - Async mode: {'✅' if async_mode else '❌'}", + f" - Async mode: {'' if async_mode else ''}", ] ) @@ -593,7 +593,7 @@ def capture_single_camera(cam_id): except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Batch capture failed: {str(e)}"}], + "content": [{"text": f"Batch capture failed: {str(e)}"}], } @@ -662,15 +662,15 @@ def _record_video_sequence( file_size = os.path.getsize(video_path) result_info = [ - "🎬 **Video Recording Complete!**", - f"🎥 Camera: {camera_type.upper()} @ {camera_id}", - f"💾 Saved: `{video_path}`", - f"📐 Resolution: {width}x{height}", - f"🎞️ Frames: {frames_captured} @ {fps} FPS", - f"⏱️ Duration: {actual_duration:.2f}s (target: {capture_duration:.2f}s)", - f"💿 File size: {file_size:,} bytes", - f"🔄 Async mode: {'✅' if async_mode else '❌'}", - f"🕐 Completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + " **Video Recording Complete!**", + f"Camera: {camera_type.upper()} @ {camera_id}", + f"Saved: `{video_path}`", + f"Resolution: {width}x{height}", + f"️ Frames: {frames_captured} @ {fps} FPS", + f"️ Duration: {actual_duration:.2f}s (target: {capture_duration:.2f}s)", + f"File size: {file_size:,} bytes", + f"Async mode: {'' if async_mode else ''}", + f"Completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ] return {"status": "success", "content": [{"text": "\n".join(result_info)}]} @@ -678,7 +678,7 @@ def _record_video_sequence( except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Video recording failed: {str(e)}"}], + "content": [{"text": f"Video recording failed: {str(e)}"}], } @@ -705,8 +705,8 @@ def _preview_camera_live( fps_counter_start = time.time() fps_frame_count = 0 - print(f"🎥 Starting live preview from 
{camera_type.upper()} camera {camera_id}") - print(f"⏱️ Duration: {preview_duration}s | Press 'q' to quit early") + print(f"Starting live preview from {camera_type.upper()} camera {camera_id}") + print(f"️ Duration: {preview_duration}s | Press 'q' to quit early") try: while time.time() - start_time < preview_duration: @@ -740,13 +740,13 @@ def _preview_camera_live( # Calculate and display FPS every second if time.time() - fps_counter_start >= 1.0: actual_fps = fps_frame_count / (time.time() - fps_counter_start) - print(f"📊 Live FPS: {actual_fps:.1f} | Frames: {frames_displayed}") + print(f"Live FPS: {actual_fps:.1f} | Frames: {frames_displayed}") fps_counter_start = time.time() fps_frame_count = 0 # Check for quit key if cv2.waitKey(1) & 0xFF == ord("q"): - print("👋 Preview stopped by user") + print(" Preview stopped by user") break # Maintain target FPS @@ -763,14 +763,14 @@ def _preview_camera_live( avg_fps = frames_displayed / actual_duration if actual_duration > 0 else 0 result_info = [ - "📺 **Live Preview Complete!**", - f"🎥 Camera: {camera_type.upper()} @ {camera_id}", - f"📐 Resolution: {width}x{height}", - f"🎞️ Frames displayed: {frames_displayed}", - f"⏱️ Duration: {actual_duration:.2f}s", - f"📊 Average FPS: {avg_fps:.2f}", - f"🎯 Target FPS: {fps}", - f"🔄 Async mode: {'✅' if async_mode else '❌'}", + " **Live Preview Complete!**", + f"Camera: {camera_type.upper()} @ {camera_id}", + f"Resolution: {width}x{height}", + f"️ Frames displayed: {frames_displayed}", + f"️ Duration: {actual_duration:.2f}s", + f"Average FPS: {avg_fps:.2f}", + f"Target FPS: {fps}", + f"Async mode: {'' if async_mode else ''}", ] return {"status": "success", "content": [{"text": "\n".join(result_info)}]} @@ -778,7 +778,7 @@ def _preview_camera_live( except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Preview failed: {str(e)}"}], + "content": [{"text": f"Preview failed: {str(e)}"}], } @@ -797,7 +797,7 @@ def _test_camera_performance( """Test camera performance 
and capabilities.""" try: test_results = [] - test_results.append("🧪 **Camera Performance Test**\n") + test_results.append(" **Camera Performance Test**\n") # Connection test start_time = time.time() @@ -805,7 +805,7 @@ def _test_camera_performance( camera.connect(warmup=warmup) connect_time = time.time() - start_time - test_results.append(f"✅ **Connection Test**: {connect_time:.3f}s") + test_results.append(f"**Connection Test**: {connect_time:.3f}s") # Frame capture test (sync) capture_times = [] @@ -819,7 +819,7 @@ def _test_camera_performance( min_sync_time = np.min(capture_times) max_sync_time = np.max(capture_times) - test_results.append("📷 **Sync Capture (10 frames)**:") + test_results.append(" **Sync Capture (10 frames)**:") test_results.append(f" - Average: {avg_sync_time:.3f}s") test_results.append(f" - Min: {min_sync_time:.3f}s") test_results.append(f" - Max: {max_sync_time:.3f}s") @@ -838,7 +838,7 @@ def _test_camera_performance( min_async_time = np.min(async_times) max_async_time = np.max(async_times) - test_results.append("⚡ **Async Capture (10 frames)**:") + test_results.append(" **Async Capture (10 frames)**:") test_results.append(f" - Average: {avg_async_time:.3f}s") test_results.append(f" - Min: {min_async_time:.3f}s") test_results.append(f" - Max: {max_async_time:.3f}s") @@ -846,7 +846,7 @@ def _test_camera_performance( test_results.append(f" - Speedup: {avg_sync_time / avg_async_time:.2f}x") # Frame properties test - test_results.append("📊 **Frame Properties**:") + test_results.append(" **Frame Properties**:") test_results.append(f" - Resolution: {frame.shape[1]}x{frame.shape[0]}") test_results.append(f" - Channels: {frame.shape[2]}") test_results.append(f" - Data type: {frame.dtype}") @@ -854,31 +854,29 @@ def _test_camera_performance( # Camera properties if hasattr(camera, "fps"): - test_results.append("⚙️ **Camera Configuration**:") + test_results.append("️ **Camera Configuration**:") test_results.append(f" - Configured FPS: {camera.fps}") 
test_results.append(f" - Resolution: {camera.width}x{camera.height}") test_results.append(f" - Color mode: {camera.color_mode.value}") camera.disconnect() - test_results.append("\n🎯 **Performance Summary**:") - test_results.append(f" - Connection: {'✅ Fast' if connect_time < 1.0 else '⚠️ Slow'} ({connect_time:.3f}s)") - test_results.append( - f" - Sync capture: {'✅ Good' if avg_sync_time < 0.1 else '⚠️ Slow'} ({avg_sync_time:.3f}s)" - ) + test_results.append("\n **Performance Summary**:") + test_results.append(f" - Connection: {' Fast' if connect_time < 1.0 else '️ Slow'} ({connect_time:.3f}s)") + test_results.append(f" - Sync capture: {' Good' if avg_sync_time < 0.1 else '️ Slow'} ({avg_sync_time:.3f}s)") if async_mode: test_results.append( - f" - Async capture: {'✅ Better' if avg_async_time < avg_sync_time else '❌ Worse'}" - f" ({avg_async_time:.3f}s)" + f" - Async capture: {' Better' if avg_async_time < avg_sync_time else ' Worse'}" + f"({avg_async_time:.3f}s)" ) - test_results.append(f" - Frame rate: {'✅ Stable' if max_sync_time - min_sync_time < 0.05 else '⚠️ Variable'}") + test_results.append(f" - Frame rate: {' Stable' if max_sync_time - min_sync_time < 0.05 else '️ Variable'}") return {"status": "success", "content": [{"text": "\n".join(test_results)}]} except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Performance test failed: {str(e)}"}], + "content": [{"text": f"Performance test failed: {str(e)}"}], } @@ -915,13 +913,13 @@ def _configure_camera_settings( actual_config["rotation"] = rotation config_info = [ - "⚙️ **Camera Configuration**", - f"🎥 Camera: {camera_type.upper()} @ {camera_id}", - f"📐 Resolution: {actual_config['width']}x{actual_config['height']}", - f"🎞️ FPS: {actual_config['fps']}", - f"🎨 Color mode: {actual_config['color_mode']}", - f"🔄 Rotation: {actual_config.get('rotation', 'NO_ROTATION')}", - f"🔧 Warmup: {'✅' if warmup else '❌'}", + "️ **Camera Configuration**", + f"Camera: {camera_type.upper()} @ 
{camera_id}", + f"Resolution: {actual_config['width']}x{actual_config['height']}", + f"️ FPS: {actual_config['fps']}", + f"Color mode: {actual_config['color_mode']}", + f"Rotation: {actual_config.get('rotation', 'NO_ROTATION')}", + f"Warmup: {'' if warmup else ''}", ] # Save configuration if requested @@ -939,7 +937,7 @@ def _configure_camera_settings( config_info.extend( [ "", - "💾 **Configuration Saved**:", + " **Configuration Saved**:", f" - File: `{config_path}`", " - Format: JSON", ] @@ -952,7 +950,7 @@ def _configure_camera_settings( except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Configuration failed: {str(e)}"}], + "content": [{"text": f"Configuration failed: {str(e)}"}], } diff --git a/strands_robots/tools/lerobot_teleoperate.py b/strands_robots/tools/lerobot_teleoperate.py index 56e0c87..a847d21 100644 --- a/strands_robots/tools/lerobot_teleoperate.py +++ b/strands_robots/tools/lerobot_teleoperate.py @@ -456,7 +456,7 @@ def lerobot_teleoperate( # Check if session already exists if session_manager.get_session(session_name): - return {"status": "error", "content": [{"text": f"❌ Session '{session_name}' already exists"}]} + return {"status": "error", "content": [{"text": f"Session '{session_name}' already exists"}]} # Build command try: @@ -489,7 +489,7 @@ def lerobot_teleoperate( play_sounds=play_sounds, ) except Exception as e: - return {"status": "error", "content": [{"text": f"❌ Command build failed: {str(e)}"}]} + return {"status": "error", "content": [{"text": f"Command build failed: {str(e)}"}]} if background: # Start in background @@ -586,15 +586,15 @@ def auto_respond(): elif action == "stop": if not session_name: - return {"status": "error", "content": [{"text": "❌ Session name required for stop action"}]} + return {"status": "error", "content": [{"text": "Session name required for stop action"}]} session_info = session_manager.get_session(session_name) # type: ignore[assignment] # narrow Optional if not session_info: 
- return {"status": "error", "content": [{"text": f"❌ Session '{session_name}' not found"}]} + return {"status": "error", "content": [{"text": f"Session '{session_name}' not found"}]} pid = session_info.get("pid") if not pid: - return {"status": "error", "content": [{"text": f"❌ No PID found for session '{session_name}'"}]} + return {"status": "error", "content": [{"text": f"No PID found for session '{session_name}'"}]} pid_int = int(pid) try: @@ -610,7 +610,7 @@ def auto_respond(): return { "status": "success", - "content": [{"text": f"🛑 **Session Stopped**\n📝 Session: `{session_name}`\n🆔 PID: {pid}"}], + "content": [{"text": f"**Session Stopped**\n📝 Session: `{session_name}`\n🆔 PID: {pid}"}], "session_name": session_name, "session_info": session_info, } @@ -620,13 +620,13 @@ def auto_respond(): session_manager.remove_session(session_name) return { "status": "success", - "content": [{"text": f"✅ Session '{session_name}' was already stopped"}], + "content": [{"text": f"Session '{session_name}' was already stopped"}], "session_name": session_name, } except Exception as e: return { "status": "error", - "content": [{"text": f"❌ Failed to stop session '{session_name}': {str(e)}"}], + "content": [{"text": f"Failed to stop session '{session_name}': {str(e)}"}], } elif action == "list": @@ -665,11 +665,11 @@ def auto_respond(): elif action == "status": if not session_name: - return {"status": "error", "content": [{"text": "❌ Session name required for status action"}]} + return {"status": "error", "content": [{"text": "Session name required for status action"}]} session_info = session_manager.get_session(session_name) # type: ignore[assignment] # narrow Optional if not session_info: - return {"status": "error", "content": [{"text": f"❌ Session '{session_name}' not found"}]} + return {"status": "error", "content": [{"text": f"Session '{session_name}' not found"}]} pid = session_info.get("pid") start_time: float = float(session_info.get("start_time") or 0) @@ -715,7 +715,7 
@@ def auto_respond(): elif action == "replay": if not dataset_repo_id: - return {"status": "error", "content": [{"text": "❌ dataset_repo_id required for replay action"}]} + return {"status": "error", "content": [{"text": "dataset_repo_id required for replay action"}]} try: cmd = build_lerobot_command( @@ -730,7 +730,7 @@ def auto_respond(): display_data=display_data, ) except Exception as e: - return {"status": "error", "content": [{"text": f"❌ Replay command build failed: {str(e)}"}]} + return {"status": "error", "content": [{"text": f"Replay command build failed: {str(e)}"}]} # Execute replay result = subprocess.run(cmd, capture_output=True, text=True) @@ -757,8 +757,8 @@ def auto_respond(): } else: - return {"status": "error", "content": [{"text": f"❌ Unknown action: {action}"}]} + return {"status": "error", "content": [{"text": f"Unknown action: {action}"}]} except Exception as e: logger.error(f"LeRobot teleoperate error: {e}") - return {"status": "error", "content": [{"text": f"❌ Tool execution failed: {str(e)}"}]} + return {"status": "error", "content": [{"text": f"Tool execution failed: {str(e)}"}]} diff --git a/strands_robots/tools/pose_tool.py b/strands_robots/tools/pose_tool.py index 5098195..dc23858 100644 --- a/strands_robots/tools/pose_tool.py +++ b/strands_robots/tools/pose_tool.py @@ -407,11 +407,11 @@ def pose_tool( if action == "show_pose": if not pose_name: - return {"status": "error", "content": [{"text": "❌ pose_name required"}]} + return {"status": "error", "content": [{"text": "pose_name required"}]} pose = pose_manager.get_pose(pose_name) if not pose: - return {"status": "error", "content": [{"text": f"❌ Pose '{pose_name}' not found"}]} + return {"status": "error", "content": [{"text": f"Pose '{pose_name}' not found"}]} motor_info = "\n".join([f" • {motor}: {pos:.2f}°" for motor, pos in pose.positions.items()]) @@ -430,16 +430,16 @@ def pose_tool( if action == "delete_pose": if not pose_name: - return {"status": "error", "content": [{"text": 
"❌ pose_name required"}]} + return {"status": "error", "content": [{"text": "pose_name required"}]} if pose_manager.delete_pose(pose_name): - return {"status": "success", "content": [{"text": f"✅ Deleted pose '{pose_name}'"}]} + return {"status": "success", "content": [{"text": f"Deleted pose '{pose_name}'"}]} else: - return {"status": "error", "content": [{"text": f"❌ Pose '{pose_name}' not found"}]} + return {"status": "error", "content": [{"text": f"Pose '{pose_name}' not found"}]} # Actions that need motor controller if not port: - return {"status": "error", "content": [{"text": "❌ port required for motor operations"}]} + return {"status": "error", "content": [{"text": "port required for motor operations"}]} controller = MotorController(port) @@ -447,17 +447,17 @@ def pose_tool( connected, error = controller.connect() if connected: controller.disconnect() - return {"status": "success", "content": [{"text": f"✅ Successfully connected to robot on {port}"}]} + return {"status": "success", "content": [{"text": f"Successfully connected to robot on {port}"}]} else: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} if action == "read_position": if not motor_name: - return {"status": "error", "content": [{"text": "❌ motor_name required"}]} + return {"status": "error", "content": [{"text": "motor_name required"}]} connected, error = controller.connect() if not connected: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: position = controller.read_motor_position(motor_name) @@ -469,14 +469,14 @@ def pose_tool( "position": position, } else: - return {"status": "error", "content": [{"text": f"❌ Failed to read {motor_name}"}]} + return {"status": "error", "content": [{"text": f"Failed to read {motor_name}"}]} finally: controller.disconnect() if action == "read_all": connected, error = controller.connect() if not 
connected: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: positions = controller.read_all_positions() @@ -493,22 +493,22 @@ def pose_tool( "positions": positions, } else: - return {"status": "error", "content": [{"text": "❌ Failed to read positions"}]} + return {"status": "error", "content": [{"text": "Failed to read positions"}]} finally: controller.disconnect() if action == "store_pose": if not pose_name: - return {"status": "error", "content": [{"text": "❌ pose_name required"}]} + return {"status": "error", "content": [{"text": "pose_name required"}]} connected, error = controller.connect() if not connected: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: current_positions = controller.read_all_positions() if not current_positions: - return {"status": "error", "content": [{"text": "❌ Failed to read current positions"}]} + return {"status": "error", "content": [{"text": "Failed to read current positions"}]} pose = pose_manager.store_pose(pose_name, current_positions, description) @@ -529,20 +529,20 @@ def pose_tool( if action == "load_pose": if not pose_name: - return {"status": "error", "content": [{"text": "❌ pose_name required"}]} + return {"status": "error", "content": [{"text": "pose_name required"}]} pose = pose_manager.get_pose(pose_name) if not pose: - return {"status": "error", "content": [{"text": f"❌ Pose '{pose_name}' not found"}]} + return {"status": "error", "content": [{"text": f"Pose '{pose_name}' not found"}]} # Validate pose is_valid, msg = pose_manager.validate_pose(pose) if not is_valid: - return {"status": "error", "content": [{"text": f"❌ Pose validation failed: {msg}"}]} + return {"status": "error", "content": [{"text": f"Pose validation failed: {msg}"}]} connected, error = controller.connect() if not connected: - return {"status": "error", "content": [{"text": f"❌ 
{error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: success = controller.move_multiple_motors(pose.positions, smooth) @@ -553,17 +553,17 @@ def pose_tool( "target_positions": pose.positions, } else: - return {"status": "error", "content": [{"text": f"❌ Failed to move to pose '{pose_name}'"}]} + return {"status": "error", "content": [{"text": f"Failed to move to pose '{pose_name}'"}]} finally: controller.disconnect() if action == "move_motor": if not motor_name or position is None: - return {"status": "error", "content": [{"text": "❌ motor_name and position required"}]} + return {"status": "error", "content": [{"text": "motor_name and position required"}]} connected, error = controller.connect() if not connected: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: success = controller.move_motor(motor_name, position) @@ -571,17 +571,17 @@ def pose_tool( unit = "%" if motor_name == "gripper" else "°" return {"status": "success", "content": [{"text": f"🎯 Moved {motor_name} to {position}{unit}"}]} else: - return {"status": "error", "content": [{"text": f"❌ Failed to move {motor_name}"}]} + return {"status": "error", "content": [{"text": f"Failed to move {motor_name}"}]} finally: controller.disconnect() if action == "move_multiple": if not positions: - return {"status": "error", "content": [{"text": "❌ positions dict required"}]} + return {"status": "error", "content": [{"text": "positions dict required"}]} connected, error = controller.connect() if not connected: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: success = controller.move_multiple_motors(positions, smooth) @@ -594,17 +594,17 @@ def pose_tool( ) return {"status": "success", "content": [{"text": f"🎯 Moved multiple motors:\n{pos_text}"}]} else: - return {"status": "error", "content": [{"text": "❌ Failed to move 
motors"}]} + return {"status": "error", "content": [{"text": "Failed to move motors"}]} finally: controller.disconnect() if action == "incremental_move": if not motor_name or delta is None: - return {"status": "error", "content": [{"text": "❌ motor_name and delta required"}]} + return {"status": "error", "content": [{"text": "motor_name and delta required"}]} connected, error = controller.connect() if not connected: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: success = controller.incremental_move(motor_name, delta) @@ -613,7 +613,7 @@ def pose_tool( sign = "+" if delta >= 0 else "" return {"status": "success", "content": [{"text": f"🔧 Moved {motor_name} by {sign}{delta}{unit}"}]} else: - return {"status": "error", "content": [{"text": f"❌ Failed to move {motor_name}"}]} + return {"status": "error", "content": [{"text": f"Failed to move {motor_name}"}]} finally: controller.disconnect() @@ -630,7 +630,7 @@ def pose_tool( connected, error = controller.connect() if not connected: - return {"status": "error", "content": [{"text": f"❌ {error}"}]} + return {"status": "error", "content": [{"text": f"{error}"}]} try: success = controller.move_multiple_motors(home_positions, smooth=True) @@ -641,20 +641,20 @@ def pose_tool( "home_positions": home_positions, } else: - return {"status": "error", "content": [{"text": "❌ Failed to move to home position"}]} + return {"status": "error", "content": [{"text": "Failed to move to home position"}]} finally: controller.disconnect() if action == "emergency_stop": # This would require torque disable in real implementation - return {"status": "success", "content": [{"text": "🛑 Emergency stop executed (torque disabled)"}]} + return {"status": "success", "content": [{"text": "Emergency stop executed (torque disabled)"}]} else: return { "status": "error", "content": [ { - "text": f"❌ Unknown action: {action}\n" + "text": f"Unknown action: {action}\n" 
"Available actions: store_pose, load_pose, list_poses, delete_pose, show_pose, " "move_motor, move_multiple, incremental_move, read_position, read_all, " "connect, reset_to_home, emergency_stop" @@ -664,4 +664,4 @@ def pose_tool( except Exception as e: logger.error(f"Pose tool error: {e}") - return {"status": "error", "content": [{"text": f"❌ Error: {str(e)}"}]} + return {"status": "error", "content": [{"text": f"Error: {str(e)}"}]} diff --git a/strands_robots/tools/serial_tool.py b/strands_robots/tools/serial_tool.py index cfdbc47..bc2e701 100644 --- a/strands_robots/tools/serial_tool.py +++ b/strands_robots/tools/serial_tool.py @@ -93,7 +93,7 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: } if not port: - return {"status": "error", "content": [{"text": "❌ Port parameter required for this action"}]} + return {"status": "error", "content": [{"text": "Port parameter required for this action"}]} # Open serial connection ser = serial.Serial(port, baudrate, timeout=timeout) @@ -103,13 +103,13 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: # Parse hex string (e.g., "FF FF 01 04" -> [0xFF, 0xFF, 0x01, 0x04]) hex_bytes = bytes.fromhex(hex_data.replace(" ", "")) ser.write(hex_bytes) - response_text = f"✅ Sent hex data: {hex_data}" + response_text = f"Sent hex data: {hex_data}" elif data: ser.write(data.encode()) - response_text = f"✅ Sent string data: {data}" + response_text = f"Sent string data: {data}" else: ser.close() - return {"status": "error", "content": [{"text": "❌ No data or hex_data provided"}]} + return {"status": "error", "content": [{"text": "No data or hex_data provided"}]} ser.close() return {"status": "success", "content": [{"text": response_text}]} @@ -140,7 +140,7 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: sent_text = f"Sent string: {data}" else: ser.close() - return {"status": "error", "content": [{"text": "❌ No data to send"}]} + return {"status": 
"error", "content": [{"text": "No data to send"}]} # Small delay then read response time.sleep(0.1) @@ -160,7 +160,7 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: elif action == "feetech_position": if motor_id is None or position is None: ser.close() - return {"status": "error", "content": [{"text": "❌ motor_id and position required"}]} + return {"status": "error", "content": [{"text": "motor_id and position required"}]} # Feetech position command: INST_WRITE (0x03), Goal_Position address (0x2A) params = [0x2A, position & 0xFF, (position >> 8) & 0xFF] @@ -178,7 +178,7 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: elif action == "feetech_velocity": if motor_id is None or velocity is None: ser.close() - return {"status": "error", "content": [{"text": "❌ motor_id and velocity required"}]} + return {"status": "error", "content": [{"text": "motor_id and velocity required"}]} # Feetech velocity command: Goal_Velocity address (0x2E) params = [0x2E, velocity & 0xFF, (velocity >> 8) & 0xFF] @@ -191,7 +191,7 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: elif action == "feetech_ping": if motor_id is None: ser.close() - return {"status": "error", "content": [{"text": "❌ motor_id required"}]} + return {"status": "error", "content": [{"text": "motor_id required"}]} # Feetech ping command packet = build_feetech_packet(motor_id, 0x01, []) # INST_PING @@ -240,7 +240,7 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: "status": "error", "content": [ { - "text": f"❌ Unknown action: {action}\n" + "text": f"Unknown action: {action}\n" "Available: list_ports, send, read, send_read," " feetech_position, feetech_velocity, feetech_ping, monitor" } @@ -248,6 +248,6 @@ def send_serial_data(ser: serial.Serial, data_to_send: str | bytes) -> None: } except serial.SerialException as e: - return {"status": "error", "content": [{"text": f"❌ Serial error: {e}"}]} + 
return {"status": "error", "content": [{"text": f"Serial error: {e}"}]} except Exception as e: - return {"status": "error", "content": [{"text": f"❌ Error: {e}"}]} + return {"status": "error", "content": [{"text": f"Error: {e}"}]} diff --git a/strands_robots/utils.py b/strands_robots/utils.py index f61de52..80a952a 100644 --- a/strands_robots/utils.py +++ b/strands_robots/utils.py @@ -53,9 +53,9 @@ def require_optional( raise ImportError("\n".join(parts)) from None -# ───────────────────────────────────────────────────────────────────── -# Path resolution — single source of truth for all strands-robots paths -# ───────────────────────────────────────────────────────────────────── +# +# Path resolution - single source of truth for all strands-robots paths +# #: Default base directory for all user data. DEFAULT_BASE_DIR = Path.home() / ".strands_robots" @@ -66,10 +66,10 @@ def get_base_dir() -> Path: Resolution (in priority order): - 1. ``STRANDS_BASE_DIR`` env var — explicit override. Use this when + 1. ``STRANDS_BASE_DIR`` env var - explicit override. Use this when you want to relocate *all* strands-robots user data (assets, user registry, caches) to a non-default location. - 2. ``~/.strands_robots/`` — default. + 2. ``~/.strands_robots/`` - default. Note: ``STRANDS_ASSETS_DIR`` **only** controls the assets subdirectory @@ -91,8 +91,8 @@ def get_assets_dir() -> Path: """Get the assets directory (robot model files, meshes, URDFs). Resolution: - 1. ``STRANDS_ASSETS_DIR`` env var — used as-is - 2. ``~/.strands_robots/assets/`` — default + 1. ``STRANDS_ASSETS_DIR`` env var - used as-is + 2. ``~/.strands_robots/assets/`` - default Returns: Path to the assets directory (created if needed). 
@@ -128,9 +128,9 @@ def resolve_asset_path(relative_or_absolute: str | Path | None, default_name: st return assets / expanded -# ───────────────────────────────────────────────────────────────────── -# Path safety — prevent traversal via untrusted components -# ───────────────────────────────────────────────────────────────────── +# +# Path safety - prevent traversal via untrusted components +# def safe_join(base: Path, untrusted: str) -> Path: @@ -166,7 +166,7 @@ def get_search_paths() -> list[Path]: """Get ordered list of asset search paths. Used by both :mod:`strands_robots.assets.manager` and - :mod:`strands_robots.assets.download` — centralised here to avoid + :mod:`strands_robots.assets.download` - centralised here to avoid a circular dependency between those two modules. Order (local assets take priority over defaults): diff --git a/tests/mocks/torch_mock.py b/tests/mocks/torch_mock.py index af124b7..5b3874b 100644 --- a/tests/mocks/torch_mock.py +++ b/tests/mocks/torch_mock.py @@ -6,9 +6,9 @@ actual GPU inference. 
Provides numpy-backed replacements for: -- torch.Tensor (MockTensor) — arithmetic, reshaping, device, slicing -- torch.nn.Parameter (MockParameter) — with requires_grad and device -- torch.device (MockDevice) — type string, equality, hashing +- torch.Tensor (MockTensor) - arithmetic, reshaping, device, slicing +- torch.nn.Parameter (MockParameter) - with requires_grad and device +- torch.device (MockDevice) - type string, equality, hashing - Factory functions: tensor, zeros, ones, randint, rand, from_numpy, stack, cat - Context managers: no_grad, inference_mode - Submodules: torch.nn, torch.cuda, torch.backends, torch.amp @@ -45,7 +45,7 @@ def __init__(self, data=None, dtype=None, device=None): else: self._data = np.array(data, dtype=np.float32) - # --- Properties --- + # Properties @property def shape(self): @@ -63,7 +63,7 @@ def dtype(self): def device(self): return MockDevice("cpu") - # --- Shape / size helpers --- + # Shape / size helpers def dim(self): return self._data.ndim @@ -76,7 +76,7 @@ def size(self, dim=None): def numel(self): return int(self._data.size) - # --- Conversion --- + # Conversion def item(self): return float(self._data.flat[0]) @@ -108,7 +108,7 @@ def to(self, *args, **kwargs): def contiguous(self): return self - # --- Reshaping --- + # Reshaping def unsqueeze(self, dim): return MockTensor(np.expand_dims(self._data, axis=dim)) @@ -127,7 +127,7 @@ def reshape(self, *shape): def permute(self, *dims): return MockTensor(np.transpose(self._data, dims)) - # --- Reduction --- + # Reduction def max(self): return float(self._data.max()) if self._data.size > 0 else 0.0 @@ -135,7 +135,7 @@ def max(self): def min(self): return float(self._data.min()) if self._data.size > 0 else 0.0 - # --- Dunder methods --- + # Dunder methods def __len__(self): return self._data.shape[0] if self._data.ndim > 0 else 1 @@ -231,9 +231,7 @@ def __call__(self, func): return func -# --------------------------------------------------------------------------- # Factory 
functions -# --------------------------------------------------------------------------- def _tensor(data, dtype=None, device=None): @@ -282,9 +280,7 @@ def _randn(*shape, dtype=None, device=None): return MockTensor(np.random.randn(*shape).astype(np.float32)) -# --------------------------------------------------------------------------- # Public API -# --------------------------------------------------------------------------- def install_torch_mock(): @@ -295,8 +291,8 @@ def install_torch_mock(): try: import torch # noqa: F401 - logger.info("Real torch is available (version=%s) — mock not installed", torch.__version__) - return # Real torch available — nothing to do + logger.info("Real torch is available (version=%s) - mock not installed", torch.__version__) + return # Real torch available - nothing to do except ImportError: pass diff --git a/tests/groot/__init__.py b/tests/policies/__init__.py similarity index 100% rename from tests/groot/__init__.py rename to tests/policies/__init__.py diff --git a/tests/policies/groot/__init__.py b/tests/policies/groot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/groot/test_client.py b/tests/policies/groot/test_client.py similarity index 93% rename from tests/groot/test_client.py rename to tests/policies/groot/test_client.py index 77cde04..44abbd8 100644 --- a/tests/groot/test_client.py +++ b/tests/policies/groot/test_client.py @@ -1,4 +1,4 @@ -"""Tests for strands_robots.policies.groot.client — ZMQ serialization and client. +"""Tests for strands_robots.policies.groot.client - ZMQ serialization and client. Covers: MsgSerializer roundtrips, Gr00tInferenceClient construction, api_token handling, and error paths. 
@@ -12,16 +12,16 @@ import numpy as np import pytest -msgpack = pytest.importorskip("msgpack", reason="msgpack not installed — pip install 'strands-robots[groot-service]'") -zmq = pytest.importorskip("zmq", reason="zmq not installed — pip install 'strands-robots[groot-service]'") +msgpack = pytest.importorskip("msgpack", reason="msgpack not installed - pip install 'strands-robots[groot-service]'") +zmq = pytest.importorskip("zmq", reason="zmq not installed - pip install 'strands-robots[groot-service]'") # E402: importorskip must execute before these imports to skip the module cleanly. from strands_robots.policies.groot.client import Gr00tInferenceClient, MsgSerializer # noqa: E402 from strands_robots.policies.groot.data_config import ModalityConfig # noqa: E402 -# --------------------------------------------------------------------------- +# (section) # MsgSerializer -# --------------------------------------------------------------------------- +# (section) class TestMsgSerializer: @@ -73,7 +73,7 @@ def test_decode_modality_config_n17_with_extra_fields(self): Our lightweight client-side dataclass only tracks ``delta_indices`` and ``modality_keys``. Unknown fields in the wire payload must be silently dropped so clients don't break when NVIDIA adds new metadata in future - N1.x releases. This was discovered live against GR00T-N1.7-3B — the + N1.x releases. This was discovered live against GR00T-N1.7-3B - the server sends ``sin_cos_embedding_keys``, ``mean_std_embedding_keys``, and ``action_configs`` on every response. 
""" @@ -129,9 +129,9 @@ def test_encode_non_custom_returns_as_is(self): assert result["num"] == 42 -# --------------------------------------------------------------------------- -# Gr00tInferenceClient — construction & api_token -# --------------------------------------------------------------------------- +# (section) +# Gr00tInferenceClient - construction & api_token +# (section) class TestGr00tInferenceClient: @@ -251,9 +251,9 @@ def test_call_endpoint_data_present_includes_data_key(self): assert sent_data[0]["data"] == {"obs": "test"} -# --------------------------------------------------------------------------- +# (section) # Dependency check -# --------------------------------------------------------------------------- +# (section) class TestZmqDeps: diff --git a/tests/groot/test_data_config.py b/tests/policies/groot/test_data_config.py similarity index 91% rename from tests/groot/test_data_config.py rename to tests/policies/groot/test_data_config.py index 51ef764..08cbe44 100644 --- a/tests/groot/test_data_config.py +++ b/tests/policies/groot/test_data_config.py @@ -1,4 +1,4 @@ -"""Tests for strands_robots.policies.groot.data_config — typed config system. +"""Tests for strands_robots.policies.groot.data_config - typed config system. Covers: Gr00tDataConfig, ModalityConfig, load_data_config, create_custom_data_config, _extends inheritance, DATA_CONFIG_MAP, and edge cases. 
@@ -22,9 +22,9 @@ _RAW_CONFIGS = _RAW["configs"] _RAW_ALIASES = _RAW.get("aliases", {}) -# --------------------------------------------------------------------------- +# (section) # ModalityConfig -# --------------------------------------------------------------------------- +# (section) class TestModalityConfig: @@ -47,9 +47,9 @@ def test_empty_lists(self): assert parsed["modality_keys"] == [] -# --------------------------------------------------------------------------- +# (section) # Gr00tDataConfig -# --------------------------------------------------------------------------- +# (section) class TestGr00tDataConfig: @@ -109,9 +109,9 @@ def test_modality_config_observation_indices_shared(self): assert modality_configs["action"].delta_indices == [0, 1, 2] -# --------------------------------------------------------------------------- +# (section) # DATA_CONFIG_MAP + _extends inheritance -# --------------------------------------------------------------------------- +# (section) class TestDataConfigMap: @@ -135,7 +135,7 @@ def test_aliases_resolve_correctly(self): assert DATA_CONFIG_MAP[alias_name] is DATA_CONFIG_MAP[target_name] def test_extends_inherits_parent_fields(self): - """so100_dualcam extends so100 — should inherit state/action keys.""" + """so100_dualcam extends so100 - should inherit state/action keys.""" parent = DATA_CONFIG_MAP["so100"] child = DATA_CONFIG_MAP["so100_dualcam"] assert child.video_keys == ["video.front", "video.wrist"] @@ -165,7 +165,7 @@ def test_unitree_g1_full_body_has_all_body_parts(self): assert f"action.{part}" in config.action_keys, f"Missing action.{part}" def test_unitree_g1_real_n17_schema(self): - """REAL_G1 embodiment (N1.7) — verified live from nvidia/GR00T-N1.7-3B. + """REAL_G1 embodiment (N1.7) - verified live from nvidia/GR00T-N1.7-3B. Captures the observation indices [-20, 0] (T=2 video context) and 40-step action horizon that are unique to REAL_G1. 
@@ -175,7 +175,7 @@ def test_unitree_g1_real_n17_schema(self): # rot6d end-effector states are the N1.7 signature assert "state.left_wrist_eef_9d" in config.state_keys assert "state.right_wrist_eef_9d" in config.state_keys - # locomotion-first action space — navigate_command is new in N1.7 + # locomotion-first action space - navigate_command is new in N1.7 assert "action.navigate_command" in config.action_keys assert "action.base_height_command" in config.action_keys # T=2 video (20 frames ago + current) and 40-step horizon @@ -213,9 +213,9 @@ def test_config_names_are_set(self): assert config.name == config_name, f"Config '{config_name}' has wrong .name: '{config.name}'" -# --------------------------------------------------------------------------- +# (section) # load_data_config -# --------------------------------------------------------------------------- +# (section) class TestLoadDataConfig: @@ -243,9 +243,9 @@ def test_load_alias(self): assert config is DATA_CONFIG_MAP[target_name] -# --------------------------------------------------------------------------- +# (section) # create_custom_data_config -# --------------------------------------------------------------------------- +# (section) class TestCreateCustomDataConfig: diff --git a/tests/groot/test_policy.py b/tests/policies/groot/test_policy.py similarity index 88% rename from tests/groot/test_policy.py rename to tests/policies/groot/test_policy.py index 4068770..e97833d 100644 --- a/tests/groot/test_policy.py +++ b/tests/policies/groot/test_policy.py @@ -1,4 +1,4 @@ -"""Tests for Gr00tPolicy — unit tests WITHOUT Isaac-GR00T installed.""" +"""Tests for Gr00tPolicy - unit tests WITHOUT Isaac-GR00T installed.""" import asyncio from unittest.mock import MagicMock, patch @@ -6,8 +6,8 @@ import numpy as np import pytest -msgpack = pytest.importorskip("msgpack", reason="msgpack not installed — pip install 'strands-robots[groot-service]'") -zmq = pytest.importorskip("zmq", reason="zmq not installed — pip 
install 'strands-robots[groot-service]'") +msgpack = pytest.importorskip("msgpack", reason="msgpack not installed - pip install 'strands-robots[groot-service]'") +zmq = pytest.importorskip("zmq", reason="zmq not installed - pip install 'strands-robots[groot-service]'") # All tests in this file require groot-service extras pytestmark = pytest.mark.skipif( @@ -28,9 +28,9 @@ _to_video_batch, ) -# --------------------------------------------------------------------------- +# (section) # Helpers -# --------------------------------------------------------------------------- +# (section) _KNOWN_DOF = { "single_arm": 5, @@ -93,9 +93,9 @@ def _make_policy(data_config="so100", version="n1.6", obs_mapping=None, action_m return p -# --------------------------------------------------------------------------- +# (section) # Construction -# --------------------------------------------------------------------------- +# (section) class TestConstruction: @@ -156,7 +156,7 @@ def test_all_configs(self): assert Gr00tPolicy(data_config=name)._mode == "service" def test_no_denoising_steps_param(self): - """denoising_steps was removed from __init__ — kwargs swallows it.""" + """denoising_steps was removed from __init__ - kwargs swallows it.""" p = Gr00tPolicy(denoising_steps=8) assert p._mode == "service" # no error, just ignored via **kwargs @@ -166,9 +166,9 @@ def test_set_robot_state_keys_is_noop(self): p.set_robot_state_keys(["a", "b"]) # should not raise -# --------------------------------------------------------------------------- +# (section) # Version detection -# --------------------------------------------------------------------------- +# (section) class TestVersion: @@ -211,7 +211,7 @@ def test_force_false_uses_cache(self): def test_detect_n17(self): """N1.7 is detected when the ``gr00t.model.gr00t_n1d7`` subpackage exists. 
- N1.6 and N1.7 share ``gr00t.policy.gr00t_policy`` — so we need a + N1.6 and N1.7 share ``gr00t.policy.gr00t_policy`` - so we need a version-specific probe. ``gr00t_n1d7`` was introduced in N1.7. """ import strands_robots.policies.groot.policy as pm @@ -261,7 +261,7 @@ def test_detect_order_prefers_n17(self): orig = pm._GROOT_VERSION pm._GROOT_VERSION = None try: - # All three probes return a spec—N1.7 must come first. + # All three probes return a spec - N1.7 must come first. with patch("importlib.util.find_spec", return_value=MagicMock()): assert _detect_groot_version(force=True) == "n1.7" finally: @@ -286,9 +286,9 @@ def fake_find_spec(name: str): pm._GROOT_VERSION = orig -# --------------------------------------------------------------------------- +# (section) # ObservationMapping -# --------------------------------------------------------------------------- +# (section) class TestObsMapping: @@ -320,9 +320,9 @@ def test_bad_lang(self): ObservationMapping(language_key="nope").validate(GR1_MMC) -# --------------------------------------------------------------------------- +# (section) # ActionMapping -# --------------------------------------------------------------------------- +# (section) class TestActionMapping: @@ -337,9 +337,9 @@ def test_bad(self): ActionMapping(actions={"nope": "j"}).validate(GR1_MMC) -# --------------------------------------------------------------------------- +# (section) # Parsing -# --------------------------------------------------------------------------- +# (section) class TestParsing: @@ -366,9 +366,9 @@ def test_action(self): assert m.actions == {"left_arm": "j", "left_hand": "g"} -# --------------------------------------------------------------------------- +# (section) # Auto-inference -# --------------------------------------------------------------------------- +# (section) class TestAutoInfer: @@ -386,9 +386,9 @@ def test_action_exact(self): assert m.actions["single_arm"] == "single_arm" -# 
--------------------------------------------------------------------------- +# (section) # Shape helpers -# --------------------------------------------------------------------------- +# (section) class TestShapes: @@ -416,7 +416,7 @@ def test_ref_from_mapped_video_keys(self): """Should only look at keys in the video_keys set.""" obs = { "cam": np.zeros((128, 128, 3)), - "state_3d": np.zeros((10, 10, 3)), # 3D state — should NOT match + "state_3d": np.zeros((10, 10, 3)), # 3D state - should NOT match } assert _reference_video_shape(obs, video_keys={"cam"}) == (128, 128, 3) @@ -434,9 +434,9 @@ def test_ref_legacy_heuristic_when_no_video_keys(self): assert _reference_video_shape(obs, video_keys=None) == (128, 128, 3) -# --------------------------------------------------------------------------- -# _prepare_observation — nested dict format -# --------------------------------------------------------------------------- +# (section) +# _prepare_observation - nested dict format +# (section) class TestPrepareObs: @@ -484,7 +484,7 @@ def test_skips_zero_fill_unknown_dof(self): video={"cam": "webcam"}, state={"arm": "single_arm"}, language_key="annotation.human.task_description" ), ) - # Clear DOF for gripper — simulate unknown + # Clear DOF for gripper - simulate unknown p._model_state_dof = {"single_arm": 5} b = p._prepare_observation({"cam": np.zeros((64, 64, 3), dtype=np.uint8), "arm": np.zeros(5)}, "t") # gripper DOF unknown → should NOT be in state dict @@ -492,9 +492,9 @@ def test_skips_zero_fill_unknown_dof(self): assert "single_arm" in b["state"] -# --------------------------------------------------------------------------- +# (section) # _unpack_actions -# --------------------------------------------------------------------------- +# (section) class TestUnpackActions: @@ -512,9 +512,9 @@ def test_empty(self): assert _make_policy(action_mapping=ActionMapping())._unpack_actions({}) == [] -# --------------------------------------------------------------------------- 
+# (section) # Full local flow -# --------------------------------------------------------------------------- +# (section) class TestLocalFlow: @@ -567,9 +567,9 @@ def test_bad_version(self): p._local_get_actions({}, "t") -# --------------------------------------------------------------------------- +# (section) # get_actions routing -# --------------------------------------------------------------------------- +# (section) class TestGetActions: @@ -598,9 +598,9 @@ def test_service(self): assert len(acts) == 16 -# --------------------------------------------------------------------------- +# (section) # Service observation + action unpack -# --------------------------------------------------------------------------- +# (section) class TestServiceObs: @@ -658,9 +658,9 @@ def test_empty_mapping(self): assert "single_arm" in result[0] -# --------------------------------------------------------------------------- +# (section) # Exports -# --------------------------------------------------------------------------- +# (section) class TestExports: diff --git a/tests/policies/lerobot_local/__init__.py b/tests/policies/lerobot_local/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_lerobot_local.py b/tests/policies/lerobot_local/test_policy.py similarity index 94% rename from tests/test_lerobot_local.py rename to tests/policies/lerobot_local/test_policy.py index 00c56c7..611a592 100644 --- a/tests/test_lerobot_local.py +++ b/tests/policies/lerobot_local/test_policy.py @@ -1,4 +1,4 @@ -"""Tests for strands_robots.policies.lerobot_local — LerobotLocalPolicy. +"""Tests for strands_robots.policies.lerobot_local - LerobotLocalPolicy. All tests run WITHOUT lerobot installed (pure mock/unit testing). 
""" @@ -9,7 +9,7 @@ import numpy as np import pytest -import torch # real or conftest mock — both work +import torch # real or conftest mock - both work from strands_robots.policies import create_policy from strands_robots.policies.lerobot_local.policy import LerobotLocalPolicy @@ -21,9 +21,9 @@ ) from strands_robots.registry import list_policy_providers -# --------------------------------------------------------------------------- +# (section) # Helpers -# --------------------------------------------------------------------------- +# (section) def _make_policy(**kwargs): @@ -72,9 +72,9 @@ def _make_loaded_policy(action_dim=6, state_dim=6, device="cpu", include_images= return policy -# --------------------------------------------------------------------------- +# (section) # Tests: Initialization -# --------------------------------------------------------------------------- +# (section) class TestLerobotLocalInit: @@ -98,9 +98,9 @@ def test_custom_actions_per_step(self): assert policy.actions_per_step == 5 -# --------------------------------------------------------------------------- +# (section) # Tests: set_robot_state_keys -# --------------------------------------------------------------------------- +# (section) class TestSetRobotStateKeys: @@ -133,9 +133,9 @@ def test_empty_keys_no_features_raises(self): policy.set_robot_state_keys([]) -# --------------------------------------------------------------------------- +# (section) # Tests: Tokenizer resolution (VLA support) -# --------------------------------------------------------------------------- +# (section) class TestResolveTokenizer: @@ -260,9 +260,9 @@ def test_no_language_indicators_returns_false(self): assert policy._needs_language_tokens() is False -# --------------------------------------------------------------------------- +# (section) # Tests: _load_model -# --------------------------------------------------------------------------- +# (section) class TestLoadModel: @@ -375,9 +375,9 @@ def 
test_auto_generates_state_keys_from_output(self): assert policy.robot_state_keys == ["joint_0", "joint_1", "joint_2", "joint_3"] -# --------------------------------------------------------------------------- +# (section) # Tests: get_actions (async) -# --------------------------------------------------------------------------- +# (section) class TestGetActions: @@ -491,9 +491,9 @@ def test_processor_bridge_postprocess_applied(self): assert actions[0]["b"] == 20.0 -# --------------------------------------------------------------------------- +# (section) # Tests: _build_observation_batch -# --------------------------------------------------------------------------- +# (section) class TestBuildObservationBatch: @@ -560,9 +560,9 @@ def test_float64_numpy_auto_cast_to_float32(self): assert batch["observation.state"].dtype == torch.float32 -# --------------------------------------------------------------------------- +# (section) # Tests: _build_batch_from_strands_format -# --------------------------------------------------------------------------- +# (section) class TestBuildBatchFromStrandsFormat: @@ -594,9 +594,9 @@ def test_empty_state_keys_raises(self): policy._build_batch_from_strands_format({"x": 1.0}, {}) -# --------------------------------------------------------------------------- +# (section) # Tests: _tensor_to_action_dicts -# --------------------------------------------------------------------------- +# (section) class TestTensorToActionDicts: @@ -624,9 +624,9 @@ def test_empty_state_keys_raises(self): policy._tensor_to_action_dicts(torch.tensor([1.0, 2.0])) -# --------------------------------------------------------------------------- +# (section) # Tests: reset -# --------------------------------------------------------------------------- +# (section) class TestReset: @@ -644,9 +644,9 @@ def test_reset_safe_when_not_loaded(self): policy.reset() # Should not raise -# --------------------------------------------------------------------------- +# (section) # 
Tests: Policy resolution helpers -# --------------------------------------------------------------------------- +# (section) class TestPolicyResolution: @@ -680,9 +680,9 @@ def test_read_policy_type_from_local_config(self, tmp_path): assert result == "act" -# --------------------------------------------------------------------------- +# (section) # Tests: Registry integration -# --------------------------------------------------------------------------- +# (section) class TestRegistryIntegration: @@ -698,9 +698,9 @@ def test_create_policy_lerobot_local_without_model(self, monkeypatch): assert policy._loaded is False -# --------------------------------------------------------------------------- +# (section) # Tests: ProcessorBridge -# --------------------------------------------------------------------------- +# (section) class TestProcessorBridge: diff --git a/tests/policies/test_base.py b/tests/policies/test_base.py new file mode 100644 index 0000000..f08fea3 --- /dev/null +++ b/tests/policies/test_base.py @@ -0,0 +1,57 @@ +"""Tests for ``strands_robots.policies.base.Policy`` ABC contract. + +Covers the ``get_actions_sync`` event-loop dispatch paths: the 'no loop' +fast path and the 'already-in-event-loop' ThreadPoolExecutor fallback. 
+""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from strands_robots.policies.base import Policy + + +class _IdentityPolicy(Policy): + """Minimal concrete Policy for testing Policy ABC's sync wrapper.""" + + def __init__(self) -> None: + self._keys = ["j0"] + + async def get_actions( + self, observation_dict: dict[str, Any], instruction: str, **kwargs: Any + ) -> list[dict[str, Any]]: + return [{"j0": 0.1}, {"j0": 0.2}] + + def set_robot_state_keys(self, robot_state_keys: list[str]) -> None: + self._keys = list(robot_state_keys) + + @property + def provider_name(self) -> str: + return "identity" + + +def test_get_actions_sync_outside_event_loop_uses_asyncio_run(): + p = _IdentityPolicy() + actions = p.get_actions_sync({"observation.state": [0.0]}, instruction="hi") + assert actions == [{"j0": 0.1}, {"j0": 0.2}] + + +def test_get_actions_sync_inside_event_loop_uses_threadpool(): + """When called from within a running event loop, the sync wrapper must + off-load to a thread pool instead of raising 'already in a loop'.""" + p = _IdentityPolicy() + + async def inner(): + # Calling the sync wrapper here forces the thread-pool branch + return p.get_actions_sync({"observation.state": [0.0]}, instruction="hi") + + actions = asyncio.run(inner()) + assert actions == [{"j0": 0.1}, {"j0": 0.2}] + + +def test_provider_name_and_state_keys(): + p = _IdentityPolicy() + assert p.provider_name == "identity" + p.set_robot_state_keys(["a", "b", "c"]) + assert p._keys == ["a", "b", "c"] diff --git a/tests/test_policies.py b/tests/policies/test_factory.py similarity index 72% rename from tests/test_policies.py rename to tests/policies/test_factory.py index 40ccf90..e1ca90b 100644 --- a/tests/test_policies.py +++ b/tests/policies/test_factory.py @@ -1,6 +1,9 @@ -"""Tests for strands_robots.policies — behavior-focused tests for the policy system.""" +"""Tests for ``strands_robots.policies.factory.create_policy``. 
-import asyncio +* provider resolution (mock / groot / lerobot_local) +* ``trust_remote_code`` security gate for HF-backed providers +* kwargs forwarding to the chosen provider +""" import pytest @@ -23,63 +26,6 @@ _groot_available = False -class TestMockPolicy: - """MockPolicy should produce deterministic sinusoidal trajectories.""" - - def test_full_lifecycle(self): - """Create -> set keys -> get actions -> verify structure and determinism.""" - p = create_policy("mock") - assert isinstance(p, MockPolicy) - assert p.provider_name == "mock" - - p.set_robot_state_keys(["j0", "j1", "j2"]) - - obs = {"observation.state": [0.0, 0.0, 0.0]} - actions = asyncio.run(p.get_actions(obs, "pick up the block")) - - # 8-step horizon, each action has all 3 keys - assert len(actions) == 8 - assert set(actions[0].keys()) == {"j0", "j1", "j2"} - - # Deterministic - p2 = MockPolicy() - p2.set_robot_state_keys(["j0", "j1", "j2"]) - actions2 = asyncio.run(p2.get_actions(obs, "different instruction")) - assert actions == actions2 - - def test_auto_generates_keys_from_observation(self): - """When no keys are set, infers dimensionality from observation.state.""" - p = MockPolicy() - obs = {"observation.state": [0.0] * 7} - actions = p.get_actions_sync(obs, "test") - assert len(actions[0]) == 7 - assert "joint_0" in actions[0] and "joint_6" in actions[0] - - def test_defaults_to_6dof(self): - """With empty observation, defaults to 6-DOF.""" - p = MockPolicy() - actions = p.get_actions_sync({}, "test") - assert len(actions[0]) == 6 - - def test_values_are_bounded_sinusoids(self): - """All action values should stay within +/-0.6.""" - p = MockPolicy() - p.set_robot_state_keys(["j0", "j1"]) - for _ in range(10): - actions = p.get_actions_sync({"observation.state": [0, 0]}, "test") - for a in actions: - for v in a.values(): - assert -0.6 <= v <= 0.6, f"Value {v} out of bounds" - - def test_get_actions_sync_works_from_sync_context(self): - """get_actions_sync() should be usable from plain 
synchronous code.""" - p = MockPolicy() - p.set_robot_state_keys(["a", "b"]) - actions = p.get_actions_sync({"observation.state": [0, 0]}, "move") - assert len(actions) == 8 - assert all(isinstance(a, dict) for a in actions) - - class TestCreatePolicy: """create_policy() should resolve shorthands, URLs, and custom registrations.""" diff --git a/tests/policies/test_mock.py b/tests/policies/test_mock.py new file mode 100644 index 0000000..c9e14bf --- /dev/null +++ b/tests/policies/test_mock.py @@ -0,0 +1,79 @@ +"""Tests for ``strands_robots.policies.mock.MockPolicy``. + +MockPolicy is the only non-ML policy provider - it generates smooth +sinusoidal actions and is the workhorse for every policy-runner / recording / +evaluate test in the suite. +""" + +import asyncio + +from strands_robots.policies import ( + MockPolicy, + create_policy, +) + +# Detect groot-service availability for conditional test grouping. +try: + import msgpack # noqa: F401 + import zmq # noqa: F401 + + _groot_available = True +except ImportError: + _groot_available = False + + +class TestMockPolicy: + """MockPolicy should produce deterministic sinusoidal trajectories.""" + + def test_full_lifecycle(self): + """Create -> set keys -> get actions -> verify structure and determinism.""" + p = create_policy("mock") + assert isinstance(p, MockPolicy) + assert p.provider_name == "mock" + + p.set_robot_state_keys(["j0", "j1", "j2"]) + + obs = {"observation.state": [0.0, 0.0, 0.0]} + actions = asyncio.run(p.get_actions(obs, "pick up the block")) + + # 8-step horizon, each action has all 3 keys + assert len(actions) == 8 + assert set(actions[0].keys()) == {"j0", "j1", "j2"} + + # Deterministic + p2 = MockPolicy() + p2.set_robot_state_keys(["j0", "j1", "j2"]) + actions2 = asyncio.run(p2.get_actions(obs, "different instruction")) + assert actions == actions2 + + def test_auto_generates_keys_from_observation(self): + """When no keys are set, infers dimensionality from observation.state.""" + p = MockPolicy() 
+ obs = {"observation.state": [0.0] * 7} + actions = p.get_actions_sync(obs, "test") + assert len(actions[0]) == 7 + assert "joint_0" in actions[0] and "joint_6" in actions[0] + + def test_defaults_to_6dof(self): + """With empty observation, defaults to 6-DOF.""" + p = MockPolicy() + actions = p.get_actions_sync({}, "test") + assert len(actions[0]) == 6 + + def test_values_are_bounded_sinusoids(self): + """All action values should stay within +/-0.6.""" + p = MockPolicy() + p.set_robot_state_keys(["j0", "j1"]) + for _ in range(10): + actions = p.get_actions_sync({"observation.state": [0, 0]}, "test") + for a in actions: + for v in a.values(): + assert -0.6 <= v <= 0.6, f"Value {v} out of bounds" + + def test_get_actions_sync_works_from_sync_context(self): + """get_actions_sync() should be usable from plain synchronous code.""" + p = MockPolicy() + p.set_robot_state_keys(["a", "b"]) + actions = p.get_actions_sync({"observation.state": [0, 0]}, "move") + assert len(actions) == 8 + assert all(isinstance(a, dict) for a in actions) diff --git a/tests/registry/__init__.py b/tests/registry/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/registry/test_format_robot_table.py b/tests/registry/test_format_robot_table.py new file mode 100644 index 0000000..6326d5e --- /dev/null +++ b/tests/registry/test_format_robot_table.py @@ -0,0 +1,74 @@ +"""Tests for ``format_robot_table`` - column width handling (issue #113).""" + +from __future__ import annotations + +from strands_robots.registry.robots import ( + _FIXED_PREFIX_WIDTH, + format_robot_table, + list_robots, +) + + +class TestDefaultWidth: + def test_default_max_line_length_is_bounded(self): + table = format_robot_table() # default max_width=100 + max_len = max(len(line) for line in table.split("\n")) + # Allow a small margin - the rule is the longest line; data rows + # should fit inside max_width + some padding for the header/rule. 
+ assert max_len <= 101, f"max line {max_len} exceeds 100 chars" + + def test_contains_header_and_total(self): + table = format_robot_table() + assert "Name" in table + assert "Category" in table + assert "Description" in table + assert f"Total: {len(list_robots())} robots" in table + + def test_contains_all_categories(self): + table = format_robot_table() + # At least one of each category should be represented in the registry. + for cat in ("arm", "humanoid", "hand"): + assert cat in table + + +class TestNarrowWidth: + def test_80_col_terminal_fits(self): + table = format_robot_table(max_width=80) + max_len = max(len(line) for line in table.split("\n")) + # 80 is a hard target for narrow terminals; our rule is <= that + 1 + # (the ellipsis adds one wide char that may not be counted). + assert max_len <= 81, f"max line {max_len} exceeds 80 chars" + + def test_descriptions_are_truncated_with_ellipsis(self): + """Long descriptions should end with the truncation marker '...'.""" + narrow = format_robot_table(max_width=80) + wide = format_robot_table(max_width=1000) + # At least one row must have been truncated at narrow width. + assert "..." in narrow + # And that same row is longer in the wide rendering. + assert "..." not in wide + + +class TestWideWidth: + def test_wide_width_disables_truncation(self): + table = format_robot_table(max_width=1000) + assert "..." not in table + + def test_minimum_desc_width_is_enforced(self): + """Even at absurdly narrow widths we keep a 20-char Description column + rather than collapsing to zero.""" + table = format_robot_table(max_width=20) + # Prefix alone is wider than 20; we clamp to + # _FIXED_PREFIX_WIDTH + 20 so every row still shows some description. + max_len = max(len(line) for line in table.split("\n")) + assert max_len >= _FIXED_PREFIX_WIDTH + 20 - 1 + + +class TestConsistency: + def test_row_count_matches_registry(self): + """The table should have (2 header + robots + 2 footer) lines. 
+ Categories with zero robots contribute no data rows.""" + table = format_robot_table() + lines = table.split("\n") + non_empty_rows = [line for line in lines[2:-2] if line.strip() and "Total:" not in line] + assert len(non_empty_rows) == len(list_robots()) diff --git a/tests/test_registry_integrity.py b/tests/registry/test_integrity.py similarity index 87% rename from tests/test_registry_integrity.py rename to tests/registry/test_integrity.py index 7667631..889d733 100644 --- a/tests/test_registry_integrity.py +++ b/tests/registry/test_integrity.py @@ -1,4 +1,4 @@ -"""Registry integrity tests — catch silent regressions in robots.json. +"""Registry integrity tests - catch silent regressions in robots.json. These tests enforce invariants on the robot registry that prevent classes of bugs like the one flagged by @awsarron on PR #84 review (2026-04-21): @@ -13,7 +13,7 @@ import pytest -REGISTRY_PATH = Path(__file__).parent.parent / "strands_robots" / "registry" / "robots.json" +REGISTRY_PATH = Path(__file__).resolve().parents[2] / "strands_robots" / "registry" / "robots.json" @pytest.fixture(scope="module") @@ -33,9 +33,9 @@ def test_every_robot_declares_auto_download_strategy(registry: dict) -> None: """Every robot with an ``asset`` block must declare HOW it gets auto-downloaded. Valid options (exactly one required): - 1. ``asset.robot_descriptions_module`` — the robot_descriptions pip module name. - 2. ``asset.source`` with ``type: "github"`` — custom GitHub source block. - 3. ``asset.auto_download: false`` — explicit opt-out (user must supply assets). + 1. ``asset.robot_descriptions_module`` - the robot_descriptions pip module name. + 2. ``asset.source`` with ``type: "github"`` - custom GitHub source block. + 3. ``asset.auto_download: false`` - explicit opt-out (user must supply assets). 
Without one of these, auto-download silently falls through to the naming-convention heuristic, which fails for most robots and only @@ -45,7 +45,7 @@ def test_every_robot_declares_auto_download_strategy(registry: dict) -> None: for name, info in registry.items(): asset = info.get("asset") if not asset: - continue # No asset block — nothing to auto-download. + continue # No asset block - nothing to auto-download. has_rd = "robot_descriptions_module" in asset has_source = isinstance(asset.get("source"), dict) and asset["source"].get("type") == "github" @@ -103,7 +103,7 @@ def _collect_aliases(registry: dict) -> dict[str, str]: def test_aliases_unique_across_registry(registry: dict) -> None: - """No two robots may declare the same alias — last-loaded would silently win.""" + """No two robots may declare the same alias - last-loaded would silently win.""" seen: dict[str, str] = {} collisions: list[str] = [] for name, info in registry.items(): @@ -118,7 +118,7 @@ def test_no_alias_shadows_canonical_name(registry: dict) -> None: """An alias must not equal the canonical name of another robot. Shadowing causes resolution order to silently determine the winner, which - is fragile — a future reorder of robots.json could flip which robot a + is fragile - a future reorder of robots.json could flip which robot a name resolves to. """ canonical = _all_canonical_names(registry) @@ -133,7 +133,7 @@ def test_no_alias_shadows_canonical_name(registry: dict) -> None: def test_hardware_only_robots_declare_lerobot_type(registry: dict) -> None: """Robots without an ``asset`` block must still declare a LeRobot hardware type. - Prevents silent typos in ``hardware.lerobot_type`` — catches a misspelled + Prevents silent typos in ``hardware.lerobot_type`` - catches a misspelled type during registry expansion rather than at teleop time. 
""" offenders: list[str] = [] diff --git a/tests/test_registry.py b/tests/registry/test_public_api.py similarity index 92% rename from tests/test_registry.py rename to tests/registry/test_public_api.py index d1775c7..23d37ee 100644 --- a/tests/test_registry.py +++ b/tests/registry/test_public_api.py @@ -1,4 +1,4 @@ -"""Tests for strands_robots.registry — tests for loader, policies, and robots modules.""" +"""Tests for strands_robots.registry - tests for loader, policies, and robots modules.""" import pytest @@ -17,11 +17,11 @@ resolve_name, ) -# ─── Loader tests ───────────────────────────────────────────────────── +# Loader tests class TestLoader: - """loader.py — JSON loading, caching, hot-reload, and validation.""" + """loader.py - JSON loading, caching, hot-reload, and validation.""" def test_load_caches_and_returns_same_object(self): """Consecutive loads without file change should return cached data.""" @@ -115,7 +115,7 @@ def test_validate_clean_data_passes(self): _validate("policies", clean_policies) -# ─── Policy resolution tests ────────────────────────────────────────── +# Policy resolution tests class TestResolvePolicy: @@ -192,7 +192,7 @@ def test_case_insensitive_shorthand(self): assert provider == "groot" -# ─── Provider lookup tests ──────────────────────────────────────────── +# Provider lookup tests class TestProviderLookup: @@ -228,7 +228,7 @@ def test_get_provider_by_alias(self): assert config["class"] == "MockPolicy" -# ─── import_policy_class tests ──────────────────────────────────────── +# import_policy_class tests class TestImportPolicyClass: @@ -254,7 +254,7 @@ def test_import_via_alias(self): assert cls is MockPolicy -# ─── build_policy_kwargs tests ──────────────────────────────────────── +# build_policy_kwargs tests class TestBuildPolicyKwargs: @@ -298,11 +298,11 @@ def test_groot_only_port_no_host_gets_default(self): assert kwargs["host"] == "localhost" # from defaults -# ─── Robot registry tests 
───────────────────────────────────────────── +# Robot registry tests class TestRobotRegistry: - """robots.py — resolve, query, filter, and format robot definitions.""" + """robots.py - resolve, query, filter, and format robot definitions.""" def test_resolve_name_canonical(self): assert resolve_name("so100") == "so100" diff --git a/tests/test_registry_resolves.py b/tests/registry/test_resolves.py similarity index 74% rename from tests/test_registry_resolves.py rename to tests/registry/test_resolves.py index 122e42e..324521f 100644 --- a/tests/test_registry_resolves.py +++ b/tests/registry/test_resolves.py @@ -8,7 +8,7 @@ - Missing ``dir`` or ``model_xml`` keys in sim-capable robots - Path traversal sequences in registry entries -The test does NOT require downloaded assets or GPU — it only validates the +The test does NOT require downloaded assets or GPU - it only validates the registry metadata itself (directory/file names, path safety). Run it in the unit or integ hatch env. @@ -20,11 +20,11 @@ import pytest -# ───────────────────────────────────────────────────────────────────── +# # Load registry directly to avoid import side effects -# ───────────────────────────────────────────────────────────────────── +# -_REGISTRY_PATH = Path(__file__).resolve().parent.parent / "strands_robots" / "registry" / "robots.json" +_REGISTRY_PATH = Path(__file__).resolve().parents[2] / "strands_robots" / "registry" / "robots.json" def _load_registry() -> dict: @@ -42,9 +42,9 @@ def _load_registry() -> dict: _SIM_ROBOT_NAMES = list(_SIM_ROBOTS.keys()) -# ───────────────────────────────────────────────────────────────────── +# # Tests for ALL robots (sim + hardware-only) -# ───────────────────────────────────────────────────────────────────── +# @pytest.mark.parametrize("name", list(_ROBOTS.keys()), ids=list(_ROBOTS.keys())) @@ -67,9 +67,9 @@ def test_registry_resolve_via_api(name: str) -> None: assert info is not None, f"get_robot({name!r}) returned None" -# 
───────────────────────────────────────────────────────────────────── +# # Tests for sim-capable robots only (have 'asset' key) -# ───────────────────────────────────────────────────────────────────── +# @pytest.mark.parametrize("name", _SIM_ROBOT_NAMES, ids=_SIM_ROBOT_NAMES) diff --git a/tests/test_user_registry.py b/tests/registry/test_user_registry.py similarity index 97% rename from tests/test_user_registry.py rename to tests/registry/test_user_registry.py index 66e5690..cf97305 100644 --- a/tests/test_user_registry.py +++ b/tests/registry/test_user_registry.py @@ -26,9 +26,9 @@ ) from strands_robots.utils import get_assets_dir, get_base_dir, resolve_asset_path -# --------------------------------------------------------------------------- +# (section) # Helpers -# --------------------------------------------------------------------------- +# (section) _MINIMAL_MJCF = '' @@ -39,7 +39,7 @@ def _isolate_registry(tmp_path, monkeypatch): ``STRANDS_BASE_DIR`` controls where ``user_robots.json`` lives. ``STRANDS_ASSETS_DIR`` controls where robot asset directories live. - The two are independent — the base dir is not derived from the assets dir. + The two are independent - the base dir is not derived from the assets dir. """ assets_dir = tmp_path / "assets" assets_dir.mkdir() @@ -301,7 +301,7 @@ def test_import_error_returns_data_unchanged(self): class TestStrandsBaseDirIntegration: """Registry file location respects STRANDS_BASE_DIR env var. - STRANDS_ASSETS_DIR intentionally does NOT move the registry — it only + STRANDS_ASSETS_DIR intentionally does NOT move the registry - it only controls where asset directories live. See utils.get_base_dir() docstring. """ @@ -350,7 +350,7 @@ def test_custom(self, tmp_path, monkeypatch): class TestGetBaseDir: """get_base_dir() returns STRANDS_BASE_DIR or ~/.strands_robots/. 
- It is independent of STRANDS_ASSETS_DIR by design — the base dir holds + It is independent of STRANDS_ASSETS_DIR by design - the base dir holds user metadata (user_robots.json) and should not move just because the user repoints the asset cache. """ diff --git a/tests/simulation/__init__.py b/tests/simulation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/simulation/mujoco/__init__.py b/tests/simulation/mujoco/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/simulation/mujoco/test_agenttool_contract.py b/tests/simulation/mujoco/test_agenttool_contract.py new file mode 100644 index 0000000..7020a84 --- /dev/null +++ b/tests/simulation/mujoco/test_agenttool_contract.py @@ -0,0 +1,607 @@ +"""T1/T13: AgentTool router contract - unknown kwargs rejected, required args friendly, +vector dims validated, tool_spec matches method signatures.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from strands_robots.simulation.mujoco.simulation import Simulation + + +@pytest.fixture +def sim(): + s = Simulation(tool_name="contract_test", mesh=False) + s.create_world() + yield s + s.cleanup() + + +class TestRouterRejectsUnknownKwargs: + """T1 DoD: Unknown top-level params must be rejected with a clear message.""" + + def test_unknown_kwarg_on_set_gravity(self, sim): + result = sim._dispatch_action("set_gravity", {"gravity": [0, 0, -9.81], "bogus_param": 42}) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "Unknown parameter 'bogus_param'" in text + assert "set_gravity" in text + assert "Valid:" in text + + def test_unknown_kwarg_on_step(self, sim): + result = sim._dispatch_action("step", {"n_steps": 5, "num_steps": 10}) + assert result["status"] == "error" + assert "Unknown parameter 'num_steps'" in result["content"][0]["text"] + + def test_unknown_kwarg_on_reset(self, sim): + result = sim._dispatch_action("reset", {"hard_reset": True}) + 
assert result["status"] == "error" + assert "Unknown parameter 'hard_reset'" in result["content"][0]["text"] + + +class TestRouterRequiredArgError: + """T1 DoD: Missing required params produce a friendly error (no Python TypeError).""" + + def test_missing_required_arg_on_add_object(self, sim): + # add_object requires `name`. Default for shape is `box` but `name` has no default. + result = sim._dispatch_action("add_object", {"shape": "box"}) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "requires parameter 'name'" in text + assert "add_object" in text + + def test_missing_required_arg_on_stop_policy(self, sim): + # stop_policy has robot_name default="" so it's not technically required; + # but apply_force requires body_name. + result = sim._dispatch_action("apply_force", {"force": [0, 0, 1]}) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "requires parameter 'body_name'" in text + + +class TestRouterValidatesVectorDims: + """T1 DoD: Vector params with wrong length rejected before reaching MuJoCo.""" + + def test_gravity_wrong_length_rejected(self, sim): + result = sim._dispatch_action("set_gravity", {"gravity": [0, 0]}) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "'gravity'" in text and "3" in text and "2" in text + + def test_position_wrong_length_rejected(self, sim): + result = sim._dispatch_action( + "add_object", + {"name": "box1", "shape": "box", "position": [0, 0]}, + ) + assert result["status"] == "error" + assert "'position'" in result["content"][0]["text"] + + def test_orientation_wrong_length_rejected(self, sim): + # orientation is a quaternion (4) + result = sim._dispatch_action( + "add_object", + {"name": "box1", "shape": "box", "orientation": [1, 0, 0]}, + ) + assert result["status"] == "error" + assert "'orientation'" in result["content"][0]["text"] + + def test_color_wrong_length_rejected(self, sim): + # color is rgba (4) + result = 
sim._dispatch_action( + "add_object", + {"name": "box1", "shape": "box", "color": [1, 0, 0]}, + ) + assert result["status"] == "error" + assert "'color'" in result["content"][0]["text"] + + def test_non_numeric_vector_component_rejected(self, sim): + result = sim._dispatch_action("set_gravity", {"gravity": [0, 0, "low"]}) + assert result["status"] == "error" + assert "numeric" in result["content"][0]["text"].lower() + + def test_non_list_vector_rejected(self, sim): + result = sim._dispatch_action("set_gravity", {"gravity": 9.81}) + assert result["status"] == "error" + assert "'gravity'" in result["content"][0]["text"] + + +class TestRouterKwargsPassthrough: + """Methods with **kwargs in signature accept unknown params without error.""" + + def test_add_object_accepts_extra_kwargs(self, sim): + # add_object has **kwargs so extra params are allowed (backwards compat). + result = sim._dispatch_action( + "add_object", + {"name": "box1", "shape": "box", "future_flag": True}, + ) + # Either success (extra key ignored) or a proper runtime error; must NOT + # be an "unknown parameter" router rejection. + if result["status"] == "error": + assert "Unknown parameter" not in result["content"][0]["text"] + + +class TestToolSpecMethodParity: + """T13 DoD: every enum action in tool_spec.json has a matching method whose + signature matches declared top-level params.""" + + # Params in tool_spec.json that are intentionally not consumed by every method + # (they are cross-cutting or action-conditional). 
+ SPEC_ONLY_ALLOWED = { + # action is the dispatch key itself + "action", + # video composite params - folded into `video` by the router + "output_path", + "fps", + # name/robot_name are aliased bi-directionally + "robot_name", + "name", + # global knobs sometimes listed at top level for LLM convenience + } + + def test_every_action_maps_to_a_method(self, sim): + import strands_robots.simulation.mujoco as _mj_mod + + spec_path = Path(_mj_mod.__file__).parent / "tool_spec.json" + spec = json.loads(spec_path.read_text()) + actions = spec["properties"]["action"]["enum"] + + missing = [] + for action in actions: + method_name = sim._ACTION_ALIASES.get(action, action) + if not hasattr(sim, method_name): + missing.append(action) + assert not missing, f"Actions without a method: {missing}" + + def test_no_method_has_silently_unused_param(self, sim): + """Known legacy drifts that the router USED to silently drop are now + either implemented or flagged by the router. This test enumerates + the pre-T1 drift cases as a regression ward.""" + # Before T1: step(num_steps), run_policy(n_steps wrong), etc. silently dropped. + # After T1: all of these rejected. Verify a sampling. + drift_cases = [ + ("step", {"num_steps": 5}), # should be `n_steps` + ("forward_kinematics", {"some_ghost_param": 1}), + ("get_features", {"unknown_filter": "a"}), + ] + for action, bad_kwargs in drift_cases: + result = sim._dispatch_action(action, bad_kwargs) + # Router must reject; must NOT silently succeed with default values. + assert result["status"] == "error", f"{action} silently accepted {bad_kwargs}" + + +class TestUnifiedNoWorldMessage: + """T14: Every action must use the same 'No world.' 
message when no world exists.""" + + @pytest.fixture + def fresh_sim(self): + """A sim with NO world.""" + s = Simulation(tool_name="no_world_test", mesh=False) + yield s + s.cleanup() + + def _assert_standard_no_world_error(self, result, action): + assert result["status"] == "error", f"{action} should error when no world" + text = result["content"][0]["text"] + assert "No world" in text, f"{action} error text lacks 'No world': {text}" + + def test_step_no_world(self, fresh_sim): + self._assert_standard_no_world_error(fresh_sim._dispatch_action("step", {"n_steps": 1}), "step") + + def test_reset_no_world(self, fresh_sim): + self._assert_standard_no_world_error(fresh_sim._dispatch_action("reset", {}), "reset") + + def test_set_gravity_no_world(self, fresh_sim): + self._assert_standard_no_world_error( + fresh_sim._dispatch_action("set_gravity", {"gravity": [0, 0, -1]}), + "set_gravity", + ) + + def test_render_no_world(self, fresh_sim): + # render returns error cleanly when no world, not a crash. + result = fresh_sim._dispatch_action("render", {}) + assert result["status"] == "error" + # render uses the unified message now: + assert "No world" in result["content"][0]["text"] + + def test_get_state_no_world(self, fresh_sim): + self._assert_standard_no_world_error(fresh_sim._dispatch_action("get_state", {}), "get_state") + + +class TestUnifiedNotFoundMessages: + """T15: Unknown-name errors use the consistent ' X not found.' 
shape.""" + + def test_robot_not_found(self, sim): + result = sim._dispatch_action("get_robot_state", {"robot_name": "ghost_bot"}) + assert result["status"] == "error" + text = result["content"][0]["text"] + assert "Robot 'ghost_bot' not found" in text + + def test_object_not_found(self, sim): + result = sim._dispatch_action("move_object", {"name": "ghost_box", "position": [0, 0, 0]}) + assert result["status"] == "error" + assert "Object 'ghost_box' not found" in result["content"][0]["text"] + + def test_body_not_found(self, sim): + result = sim._dispatch_action("apply_force", {"body_name": "ghost_body", "force": [0, 0, 1]}) + assert result["status"] == "error" + assert "Body 'ghost_body' not found" in result["content"][0]["text"] + + def test_sensor_not_found(self, sim): + result = sim._dispatch_action("get_sensor_data", {"sensor_name": "ghost_sensor"}) + assert result["status"] == "error" + text = result["content"][0]["text"] + # T45 is about distinguishing "no sensors" vs "not found"; at minimum the + # current behaviour must mention the sensor name clearly. + assert "ghost_sensor" in text + + +class TestIdempotentStopFamily: + """T16: stop_recording, stop_cameras_recording, stop_policy and close_viewer + can be called unconditionally - when already stopped they succeed with a + distinguishable 'Was not ...' message.""" + + def test_stop_recording_twice_is_idempotent(self, sim): + r1 = sim.stop_recording() + assert r1["status"] == "success" + r2 = sim.stop_recording() + assert r2["status"] == "success" + assert "Was not recording" in r2["content"][0]["text"] + + def test_stop_cameras_recording_twice_is_idempotent(self, sim): + r1 = sim.stop_cameras_recording() + assert r1["status"] == "success" + r2 = sim.stop_cameras_recording() + assert r2["status"] == "success" + + def test_close_viewer_twice_is_idempotent(self, sim): + # close_viewer was already idempotent - pin it with a regression test. 
+ assert sim.close_viewer()["status"] == "success" + assert sim.close_viewer()["status"] == "success" + + +class TestStopPolicyContract: + """T16 + T24: stop_policy requires a robot_name; is idempotent per robot.""" + + def test_stop_policy_empty_robot_name_friendly_error(self, sim): + r = sim._dispatch_action("stop_policy", {}) + assert r["status"] == "error" + assert "requires" in r["content"][0]["text"].lower() and "robot_name" in r["content"][0]["text"] + + def test_stop_policy_unknown_robot_errors(self, sim): + r = sim._dispatch_action("stop_policy", {"robot_name": "ghost_bot"}) + assert r["status"] == "error" + assert "Robot 'ghost_bot' not found" in r["content"][0]["text"] + + +class TestForwardPassBeforeReads: + """T18/T19: get_mass_matrix, get_contacts run mj_forward first so values + are valid immediately after a reset / add_robot / load_state, not just + after a full mj_step.""" + + def test_get_mass_matrix_after_reset_is_valid(self, sim): + sim.reset() + r = sim._dispatch_action("get_mass_matrix", {}) + assert r["status"] == "success" + # Empty scene: nv==0 so rank==0 and cond==inf are acceptable; the + # important bit is we didn't return NaN / raise. + payload = r["content"][-1].get("json", {}) if isinstance(r["content"][-1], dict) else {} + assert "shape" in payload + + def test_get_contacts_at_t0_no_phantom_penetrations(self, sim): + # Empty world has no contacts; running this at t=0 must succeed + # and return an empty list (T19 used to surface stale/uninit data). + sim.reset() + r = sim._dispatch_action("get_contacts", {}) + assert r["status"] == "success" + payload = r["content"][-1]["json"] if isinstance(r["content"][-1], dict) else {} + contacts = payload.get("contacts", []) + # An empty world has no contacts. If the fix isn't applied and stale + # data surfaces, contacts may contain garbage names/distances. Assert + # either empty or all distances > -1mm (no phantom deep penetrations). 
+ for c in contacts: + assert c["dist"] > -0.001, f"phantom penetration: {c}" + + +class TestRenderDimValidation: + """T20: non-positive width/height rejected; oversized dims get plain-English + message instead of raw MuJoCo framebuffer error.""" + + def test_zero_width_rejected(self, sim): + r = sim._dispatch_action("render", {"width": 0, "height": 120}) + assert r["status"] == "error" + assert "width and height must be > 0" in r["content"][0]["text"] + + def test_negative_height_rejected(self, sim): + r = sim._dispatch_action("render", {"width": 160, "height": -10}) + assert r["status"] == "error" + assert "must be > 0" in r["content"][0]["text"] + + def test_oversize_dim_message_is_friendly(self, sim): + # Request 8000x8000 - well above any sane offscreen framebuffer cap. + r = sim._dispatch_action("render", {"width": 8000, "height": 8000}) + assert r["status"] == "error" + text = r["content"][0]["text"] + assert "exceeds" in text + assert "framebuffer" in text + assert "offwidth" in text # points at the fix + + +class TestRenderDepthSurfaces: + """T21: render_depth mac warning surfaces in the response text when the + driver lacks ARB_clip_control. Skipped when the warning isn't triggered + (Linux / modern macOS GPUs may or may not hit it).""" + + def test_render_depth_returns_well_formed_response(self, sim): + # Just check render_depth runs cleanly; the T21-specific warning + # only fires on macOS without ARB_clip_control so we only assert + # presence-of-warning when _depth_warn_text is set. + r = sim._dispatch_action("render_depth", {}) + # Some headless envs don't have GL: we only care the response shape + # is valid either way. + assert r["status"] in ("success", "error") + if r["status"] == "success": + text = r["content"][0]["text"] + # If a warning was captured, it must be on the response. 
+ warn_cached = getattr(sim, "_depth_warn_text", "") + if warn_cached: + assert warn_cached in text + + +class TestFeatureFilters: + """T32 / T33: forward_kinematics + get_features honor per-entity filters.""" + + def test_forward_kinematics_body_name_filters(self, sim): + # Empty world: world body exists but any custom name is absent. + r = sim._dispatch_action("forward_kinematics", {"body_name": "ghost_body"}) + assert r["status"] == "error" + assert "Body 'ghost_body' not found" in r["content"][0]["text"] + + def test_forward_kinematics_no_filter_returns_all(self, sim): + r = sim._dispatch_action("forward_kinematics", {}) + assert r["status"] == "success" + payload = r["content"][-1]["json"] if isinstance(r["content"][-1], dict) else {} + assert "bodies" in payload + + def test_get_features_unknown_robot_errors(self, sim): + r = sim._dispatch_action("get_features", {"robot_name": "ghost_bot"}) + assert r["status"] == "error" + assert "Robot 'ghost_bot' not found" in r["content"][0]["text"] + + def test_get_features_no_filter_returns_all(self, sim): + r = sim._dispatch_action("get_features", {}) + assert r["status"] == "success" + + +class TestRegisterUrdfValidation: + """T35 / T42: register_urdf validates path + router covers no-args.""" + + def test_register_urdf_no_args_friendly_error(self, sim): + r = sim._dispatch_action("register_urdf", {}) + assert r["status"] == "error" + assert "requires parameter" in r["content"][0]["text"] + + def test_register_urdf_missing_file_errors(self, sim): + r = sim._dispatch_action( + "register_urdf", + {"data_config": "my_bot", "urdf_path": "/nonexistent/nope.urdf"}, + ) + assert r["status"] == "error" + assert "file not found" in r["content"][0]["text"].lower() + + def test_register_urdf_empty_path_errors(self, sim): + r = sim._dispatch_action("register_urdf", {"data_config": "my_bot", "urdf_path": ""}) + assert r["status"] == "error" + # Router handles empty string as missing? 
No - it's a truthy string + # in the presence test. So we hit our explicit empty guard. + assert "non-empty" in r["content"][0]["text"] or "requires parameter" in r["content"][0]["text"] + + +class TestDuplicateCameraName: + """T30 / T41: add_camera rejects duplicate names instead of silently + overwriting the registry entry while leaving the XML unchanged.""" + + def test_duplicate_camera_rejected(self, sim): + r1 = sim._dispatch_action( + "add_camera", + {"name": "dupe", "position": [0.5, 0.5, 0.5], "target": [0, 0, 0]}, + ) + assert r1["status"] == "success", r1 + r2 = sim._dispatch_action( + "add_camera", + {"name": "dupe", "position": [1, 0, 0], "target": [0, 0, 0]}, + ) + assert r2["status"] == "error" + assert "already exists" in r2["content"][0]["text"] + + +class TestPlaneAutoStatic: + """T29: add_object(shape='plane') auto-sets is_static=True.""" + + def test_plane_default_is_static(self, sim): + r = sim._dispatch_action("add_object", {"name": "floor1", "shape": "plane"}) + assert r["status"] == "success" + assert sim._world.objects["floor1"].is_static is True + + def test_plane_with_explicit_dynamic_errors(self, sim): + r = sim._dispatch_action("add_object", {"name": "bad_floor", "shape": "plane", "is_static": False}) + assert r["status"] == "error" + assert "plane" in r["content"][0]["text"].lower() and "is_static" in r["content"][0]["text"] + + +class TestSetGeomPropertiesAlias: + """T28: set_geom_properties accepts the object name as a stand-in for the + MJCF-injected '{name}_geom' geom name.""" + + def test_object_name_resolves_to_geom(self, sim): + sim._dispatch_action( + "add_object", + {"name": "box_alpha", "shape": "box", "size": [0.05, 0.05, 0.05]}, + ) + # Using the object name, not '{name}_geom', should work - the + # T28 alias resolves to '{name}_geom' internally. 
+ r = sim._dispatch_action("set_geom_properties", {"geom_name": "box_alpha", "color": [1, 0, 0, 1]}) + # Success proves the alias resolved; error with 'Geom not found' would + # mean T28 didn't kick in. + assert r["status"] == "success", r + assert "box_alpha" in r["content"][0]["text"] or "geom" in r["content"][0]["text"].lower() + + +class TestEvalPolicyDefaults: + """T34: eval_policy requires robot_name; n_episodes default is 1.""" + + def test_eval_policy_missing_robot_name_errors(self, sim): + r = sim._dispatch_action("eval_policy", {}) + assert r["status"] == "error" + assert "robot_name" in r["content"][0]["text"] + + def test_eval_policy_unknown_robot_errors(self, sim): + r = sim._dispatch_action("eval_policy", {"robot_name": "ghost"}) + assert r["status"] == "error" + # Either "Robot X not found" (world has robots) or "No robots in sim" + # (empty scene) - both are correct paths. + text = r["content"][0]["text"] + assert "ghost" in text or "No robots" in text + + +class TestRecordingStatusLifecycle: + """T31: get_recording_status succeeds in every state (no world / not + recording / recording) with distinguishing text.""" + + def test_no_world_returns_success(self): + s = Simulation(tool_name="rec_lifecycle_nw", mesh=False) + try: + r = s._dispatch_action("get_recording_status", {}) + assert r["status"] == "success" + assert "No world" in r["content"][0]["text"] + finally: + s.cleanup() + + def test_not_recording_returns_success(self, sim): + r = sim._dispatch_action("get_recording_status", {}) + assert r["status"] == "success" + assert "Not recording" in r["content"][0]["text"] + + +class TestListRobotsPolicyStatus: + """T37: list_robots reports per-robot policy status. Regression ward.""" + + def test_list_robots_shows_idle_when_no_policy(self, sim): + r = sim._dispatch_action("list_robots", {}) + assert r["status"] == "success" + # No robots added, so we just expect the "No robots" message. 
+ assert "No robots" in r["content"][0]["text"] or "🤖" in r["content"][0]["text"] + + +class TestPolicyHorizonUnification: + """T25: run_policy and start_policy accept n_steps (primary) / max_steps + (legacy) as alternatives to duration. duration = n_steps / control_freq.""" + + def test_run_policy_n_steps_zero_errors(self, sim): + r = sim._dispatch_action("run_policy", {"robot_name": "ghost", "n_steps": 0}) + assert r["status"] == "error" + # Either n_steps validation fires first, or robot-not-found; both are + # acceptable error paths - we just want NO silent success. + text = r["content"][0]["text"] + assert ("n_steps" in text and "> 0" in text) or "Robot" in text + + def test_run_policy_negative_n_steps_errors(self, sim): + r = sim._dispatch_action("run_policy", {"robot_name": "ghost", "n_steps": -10}) + assert r["status"] == "error" + + +class TestAddRobotDeprecation: + """T22: the `name`-as-registry-fallback path emits a DeprecationWarning.""" + + def test_add_robot_name_fallback_warns(self, sim): + import warnings + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + # 'mock_never_registered' won't resolve to anything, so the + # fallback is attempted but also fails. We only care the + # warning was triggered in the path. + r = sim._dispatch_action("add_robot", {"name": "mock_never_registered"}) + # Either succeeded (name happened to resolve -> warning) or failed. + # Just verify: if it succeeded via name fallback, a warning fired. + warn_texts = [str(w.message) for w in captured if issubclass(w.category, DeprecationWarning)] + if r["status"] == "success": + assert any("deprecated" in t.lower() for t in warn_texts) + + +class TestMixedDataConfigRobots: + """Regression: robots with different ``data_config`` values can coexist + in one scene even when their MJCFs declare colliding nested default + class names (e.g. ```` in both). 
+ + Pre-fix, adding an ``h1`` humanoid after two ``so100`` arms errored with + MuJoCo's *"repeated default class name"*. Fixed by per-config namespacing + in scene_ops. + """ + + def test_two_arms_plus_humanoid_coexist(self, sim): + r1 = sim.add_robot(name="alice", data_config="so100", position=[-0.6, 0, 0]) + assert r1["status"] == "success", r1["content"][0].get("text") + r2 = sim.add_robot(name="bob", data_config="so100", position=[0.6, 0, 0]) + assert r2["status"] == "success", r2["content"][0].get("text") + r3 = sim.add_robot(name="carol", data_config="h1", position=[0, 1.0, 0]) + assert r3["status"] == "success", r3["content"][0].get("text") + assert set(sim._world.robots.keys()) == {"alice", "bob", "carol"} + + def test_four_different_configs_coexist(self, sim): + specs = [ + ("alice", "so100", [-0.6, 0, 0]), + ("bob", "so100", [0.6, 0, 0]), + ("carol", "h1", [0, 1.0, 0]), + ("dan", "panda", [0, -1.0, 0]), + ] + for name, cfg, pos in specs: + r = sim.add_robot(name=name, data_config=cfg, position=pos) + assert r["status"] == "success", f"add_robot({name}, {cfg}) failed: {r['content'][0].get('text')}" + r = sim.step(n_steps=5) + assert r["status"] == "success" + # Ensure the physics actually advanced (forward kinematics would be + # blocked by any lingering compile error). + assert abs(sim._world.sim_time - 0.010) < 1e-9 + + +class TestRemoveRobotActuallyRemoves: + """Regression: remove_robot used to only pop the Python dict entry; + the robot's MJCF bodies/actuators/sensors stayed in the compiled model. + That blocked re-adding the same name and left stale DOFs consuming + physics time per step. 
+ """ + + def test_remove_robot_empties_model(self, sim): + r = sim.add_robot(name="alice", data_config="so100") + assert r["status"] == "success" + njnt_before = sim._world._model.njnt + assert njnt_before > 0, "precondition: robot should have added joints" + + r = sim.remove_robot(name="alice") + assert r["status"] == "success" + assert sim._world._model.njnt == 0 + assert sim._world._model.nbody == 1 # just the world root body + assert "alice" not in sim._world.robots + + def test_readd_same_name_after_remove(self, sim): + """Adding a robot, removing it, then adding again with the same name + must succeed (MuJoCo rejects duplicate body names otherwise).""" + assert sim.add_robot(name="alice", data_config="so100")["status"] == "success" + assert sim.remove_robot(name="alice")["status"] == "success" + r = sim.add_robot(name="alice", data_config="so100") + assert r["status"] == "success", r["content"][0].get("text") + assert sim._world._model.njnt == 6 # so100 has 6 joints + + def test_remove_middle_of_three_robots(self, sim): + sim.add_robot(name="alice", data_config="so100", position=[-0.5, 0, 0]) + sim.add_robot(name="bob", data_config="so100", position=[0.5, 0, 0]) + sim.add_robot(name="carol", data_config="h1", position=[0, 1, 0]) + njnt_before = sim._world._model.njnt + + r = sim.remove_robot(name="bob") + assert r["status"] == "success" + assert set(sim._world.robots) == {"alice", "carol"} + # bob was 6 joints; alice (6) + carol (19) = 25 should remain. 
+ assert sim._world._model.njnt == njnt_before - 6 diff --git a/tests/simulation/mujoco/test_backend.py b/tests/simulation/mujoco/test_backend.py new file mode 100644 index 0000000..aec613c --- /dev/null +++ b/tests/simulation/mujoco/test_backend.py @@ -0,0 +1,141 @@ +"""Unit tests for mujoco/backend.py - GL backend auto-configuration.""" + +from __future__ import annotations + +import os +import sys +from unittest.mock import patch + +import pytest + +from strands_robots.simulation.mujoco import backend as backend_mod + + +@pytest.fixture +def restore_env(monkeypatch): + """Isolate MUJOCO_GL / DISPLAY / WAYLAND_DISPLAY per test.""" + for var in ("MUJOCO_GL", "DISPLAY", "WAYLAND_DISPLAY"): + monkeypatch.delenv(var, raising=False) + yield monkeypatch + + +class TestIsHeadless: + """``_is_headless`` only returns True on Linux with no display server.""" + + def test_non_linux_is_not_headless(self, restore_env): + with patch.object(sys, "platform", "darwin"): + assert backend_mod._is_headless() is False + + def test_linux_with_display_not_headless(self, restore_env): + restore_env.setenv("DISPLAY", ":0") + with patch.object(sys, "platform", "linux"): + assert backend_mod._is_headless() is False + + def test_linux_with_wayland_not_headless(self, restore_env): + restore_env.setenv("WAYLAND_DISPLAY", "wayland-0") + with patch.object(sys, "platform", "linux"): + assert backend_mod._is_headless() is False + + def test_linux_no_display_is_headless(self, restore_env): + with patch.object(sys, "platform", "linux"): + assert backend_mod._is_headless() is True + + +class TestConfigureGLBackend: + """``_configure_gl_backend`` respects MUJOCO_GL and probes EGL then OSMesa.""" + + def test_respects_user_mujoco_gl(self, restore_env): + restore_env.setenv("MUJOCO_GL", "glfw") + backend_mod._configure_gl_backend() + # Value unchanged. 
+ assert os.environ["MUJOCO_GL"] == "glfw" + + def test_noop_on_non_headless(self, restore_env): + with patch.object(sys, "platform", "darwin"): + backend_mod._configure_gl_backend() + # Nothing was set. + assert "MUJOCO_GL" not in os.environ + + def test_headless_picks_egl_when_available(self, restore_env): + with ( + patch.object(sys, "platform", "linux"), + patch("strands_robots.simulation.mujoco.backend.ctypes.cdll.LoadLibrary") as load, + ): + load.side_effect = [None] + try: + backend_mod._configure_gl_backend() + assert os.environ.get("MUJOCO_GL") == "egl" + load.assert_called_once() + finally: + # explicit teardown - monkeypatch.delenv only covers vars it had seen at yield time + os.environ.pop("MUJOCO_GL", None) + + def test_headless_falls_back_to_osmesa(self, restore_env): + with ( + patch.object(sys, "platform", "linux"), + patch("strands_robots.simulation.mujoco.backend.ctypes.cdll.LoadLibrary") as load, + ): + load.side_effect = [OSError("no libEGL"), None] + try: + backend_mod._configure_gl_backend() + assert os.environ.get("MUJOCO_GL") == "osmesa" + assert load.call_count == 2 + finally: + os.environ.pop("MUJOCO_GL", None) + + def test_headless_without_any_gl_warns(self, restore_env, caplog): + import logging + + with ( + patch.object(sys, "platform", "linux"), + patch("strands_robots.simulation.mujoco.backend.ctypes.cdll.LoadLibrary") as load, + ): + load.side_effect = OSError("no GL") + with caplog.at_level(logging.WARNING, logger="strands_robots.simulation.mujoco.backend"): + backend_mod._configure_gl_backend() + # MUJOCO_GL stays unset. + assert "MUJOCO_GL" not in os.environ + # Warning text lists both libraries. 
+ assert any("EGL" in rec.message and "OSMesa" in rec.message for rec in caplog.records) + + +class TestCanRender: + """``_can_render`` caches the probe result and short-circuits on headless+no-GL.""" + + def _clear_cache(self): + backend_mod._rendering_available = None + + def test_returns_cached_value(self): + self._clear_cache() + backend_mod._rendering_available = True + assert backend_mod._can_render() is True + + backend_mod._rendering_available = False + assert backend_mod._can_render() is False + self._clear_cache() + + def test_headless_without_mujoco_gl_short_circuits(self, restore_env): + """Probe must NOT run when headless+no-GL - otherwise GLFW SIGABRTs.""" + self._clear_cache() + with patch.object(sys, "platform", "linux"): + # No DISPLAY, no MUJOCO_GL. + assert backend_mod._can_render() is False + # Cached result remembers the negative. + assert backend_mod._rendering_available is False + self._clear_cache() + + +class TestEnsureMujoco: + """``_ensure_mujoco`` returns a module-like object with MjModel/MjData.""" + + def test_returns_module(self): + mj = backend_mod._ensure_mujoco() + # Smoke: these attributes must exist on the real module. + assert hasattr(mj, "MjModel") + assert hasattr(mj, "MjData") + assert hasattr(mj, "mj_step") + + def test_is_cached(self): + first = backend_mod._ensure_mujoco() + second = backend_mod._ensure_mujoco() + assert first is second diff --git a/tests/simulation/mujoco/test_concurrency.py b/tests/simulation/mujoco/test_concurrency.py new file mode 100644 index 0000000..654518d --- /dev/null +++ b/tests/simulation/mujoco/test_concurrency.py @@ -0,0 +1,1116 @@ +"""Regression tests for PR #85 review feedback. + +Tests: +1. Thread-safety: concurrent dispatch + policy doesn't corrupt state +2. Flat-index state copy: joint positions survive object injection +3. apply_force: force is latched (persists across steps) +4. 
Camera recording roundtrip: namespaced cameras survive schema reconcile + +Run: MUJOCO_GL=osmesa python -m pytest tests/simulation/mujoco/test_concurrency.py -v +""" + +import math +import os +import shutil +import tempfile +import threading +import time + +import numpy as np +import pytest + +mj = pytest.importorskip("mujoco") + +from strands_robots.simulation.mujoco.backend import _can_render # noqa: E402 +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +requires_gl = pytest.mark.skipif( + not _can_render(), + reason="No OpenGL context available (headless without EGL/OSMesa)", +) + +# Test robot XML (simple 3-DOF arm) + +ROBOT_XML = """ + + + +""" + + +@pytest.fixture +def robot_xml_path(): + """Write test robot XML to a temp file.""" + tmpdir = tempfile.mkdtemp() + path = os.path.join(tmpdir, "test_arm.xml") + with open(path, "w") as f: + f.write(ROBOT_XML) + yield path + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.fixture +def sim_with_robot(robot_xml_path): + """Simulation with world + robot loaded.""" + sim = Simulation(tool_name="test_regression", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + result = sim.add_robot("arm1", urdf_path=robot_xml_path) + assert result["status"] == "success" + yield sim + sim.cleanup() + + +class TestFlatIndexStatePreservation: + """Regression: joint positions must survive object injection (layout shift).""" + + def test_joint_survives_object_injection(self, sim_with_robot): + """Set a joint to π/3, inject an object, verify joint is still ≈π/3. + + This catches the flat-index qpos copy bug where injected bodies + shift existing qpos entries. 
+ """ + sim = sim_with_robot + target_angle = math.pi / 3 + + # Set shoulder_pan to π/3 + result = sim.set_joint_positions( + positions={"shoulder_pan": target_angle}, + robot_name="arm1", + ) + assert result["status"] == "success" + + # Verify it's set + state = sim.get_robot_state("arm1") + assert state["status"] == "success" # state returned + # Read qpos directly + model = sim._world._model + jid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "arm1/shoulder_pan") + if jid < 0: + jid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "shoulder_pan") + assert jid >= 0 + qpos_before = float(sim._world._data.qpos[model.jnt_qposadr[jid]]) + assert abs(qpos_before - target_angle) < 1e-6 + + # Inject an object (triggers XML round-trip + _reload_scene_from_xml) + result = sim.add_object( + "test_box", + shape="box", + position=[0.5, 0.5, 0.1], + size=[0.05, 0.05, 0.05], + ) + assert result["status"] == "success" + + # Verify joint is still ≈π/3 after injection + model = sim._world._model + jid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "arm1/shoulder_pan") + if jid < 0: + jid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "shoulder_pan") + assert jid >= 0 + qpos_after = float(sim._world._data.qpos[model.jnt_qposadr[jid]]) + assert abs(qpos_after - target_angle) < 1e-4, ( + f"Joint drifted from {target_angle:.6f} to {qpos_after:.6f} after object injection" + ) + + +class TestApplyForceLatchedBehavior: + """Regression: apply_force is latched (persists across steps).""" + + def test_force_persists_across_multiple_steps(self, sim_with_robot): + """Apply lateral force to a body, step 50 times, verify body moved. + + This validates the docstring contract: force is latched in + qfrc_applied and applied on every subsequent step. 
+ + NOTE: We use an X-force (lateral) because a Z-force along the + kinematic chain of hinge joints produces zero generalized torque + (mj_applyFT maps Cartesian force to joint space; Z-force at CoM + compresses the chain without creating torques on Y-axis hinges). + """ + sim = sim_with_robot + + # Get initial x position of link2 + model = sim._world._model + data = sim._world._data + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "arm1/link2") + if body_id < 0: + body_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "link2") + assert body_id >= 0 + + x_before = float(data.xpos[body_id, 0]) + + # Apply strong lateral (X) force - this creates torques on Y-axis hinges + result = sim.apply_force("link2", force=[100.0, 0, 0]) + assert result["status"] == "success" + + # Step physics 50 times - force should persist (latched) + sim.step(n_steps=50) + + x_after = float(data.xpos[body_id, 0]) + # Body should have moved laterally due to persistent force + assert abs(x_after - x_before) > 1e-4, ( + f"Body did not move (x_before={x_before:.6f}, x_after={x_after:.6f}). " + "Force may not be persisting across steps." + ) + + def test_zero_force_stops_effect(self, sim_with_robot): + """Apply force, then zero it, verify force buffer is cleared.""" + sim = sim_with_robot + + # Apply lateral (X) force - produces non-zero generalized torques + sim.apply_force("link2", force=[50.0, 0, 0]) + assert np.any(sim._world._data.qfrc_applied != 0), "X-force on link2 should produce non-zero generalized forces" + + # Zero it - apply_force zeros buffer first, then applies zero force + sim.apply_force("link2", force=[0, 0, 0]) + # After zeroing + applying zero force/torque, buffer should be all zeros + assert np.allclose(sim._world._data.qfrc_applied, 0.0) + + +class TestThreadSafety: + """Regression: concurrent operations don't corrupt MuJoCo state.""" + + def test_concurrent_step_and_reset_no_crash(self, sim_with_robot): + """Concurrent step() and reset() must not SIGSEGV. 
+ + Both acquire self._lock, so they serialize. This test verifies + the lock is actually held (no segfault, no exception). + """ + sim = sim_with_robot + errors = [] + + def stepper(): + try: + for _ in range(100): + sim.step(n_steps=1) + time.sleep(0.001) + except Exception as e: + errors.append(f"stepper: {e}") + + def resetter(): + try: + for _ in range(10): + sim.reset() + time.sleep(0.01) + except Exception as e: + errors.append(f"resetter: {e}") + + t1 = threading.Thread(target=stepper) + t2 = threading.Thread(target=resetter) + t1.start() + t2.start() + t1.join(timeout=10) + t2.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + + def test_concurrent_set_joint_and_step(self, sim_with_robot): + """Concurrent set_joint_positions and step must serialize safely.""" + sim = sim_with_robot + errors = [] + + def setter(): + try: + for i in range(50): + sim.set_joint_positions( + positions={"shoulder_pan": float(i) * 0.01}, + robot_name="arm1", + ) + time.sleep(0.001) + except Exception as e: + errors.append(f"setter: {e}") + + def stepper(): + try: + for _ in range(50): + sim.step(n_steps=2) + time.sleep(0.001) + except Exception as e: + errors.append(f"stepper: {e}") + + t1 = threading.Thread(target=setter) + t2 = threading.Thread(target=stepper) + t1.start() + t2.start() + t1.join(timeout=10) + t2.join(timeout=10) + + assert not errors, f"Thread errors: {errors}" + + +# Robot XML for multi-robot asset directory test + +ROBOT_B_XML = """ + + + + + + + + + + + + +""" + + +class TestRecordingRoundtripCameraFrames: + """Regression: namespaced cameras survive schema reconcile and have frames. + + @yinsong1986 review (2026-04-30): "Please add a round-trip test: + start_recording → run_policy → stop_recording, reopen the dataset, + assert the camera feature has non-zero frames." 
+ """ + + @pytest.fixture + def sim_with_namespaced_camera(self, robot_xml_path, tmp_path): + """Sim with a robot whose camera name contains '/' (namespace).""" + sim = Simulation(tool_name="test_recording", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + result = sim.add_robot("arm1", urdf_path=robot_xml_path) + assert result["status"] == "success" + yield sim + sim.cleanup() + + @requires_gl + def test_recording_roundtrip_has_camera_frames(self, sim_with_namespaced_camera, tmp_path): + """Record → run mock policy → stop → verify dataset has camera data. + + This validates the /→__ sanitization fix doesn't silently drop frames. + The test robot XML has camera 'arm0/wrist_cam' which becomes + 'arm0__wrist_cam' in the dataset schema. + """ + pytest.importorskip("lerobot") + from pathlib import Path + + sim = sim_with_namespaced_camera + ds_root = str(tmp_path / "roundtrip_ds") + + # Start recording + result = sim._dispatch_action( + "start_recording", + {"repo_id": "local/rt-test", "root": ds_root, "fps": 10, "overwrite": True}, + ) + assert result["status"] == "success", f"start_recording failed: {result}" + + # Run mock policy for a short burst (generates frames via on_frame hook) + result = sim._dispatch_action( + "run_policy", + { + "robot_name": "arm1", + "policy_provider": "mock", + "duration": 0.5, + "control_frequency": 10, + }, + ) + assert result["status"] == "success", f"run_policy failed: {result}" + + # Stop recording + result = sim._dispatch_action("stop_recording", {}) + assert result["status"] == "success", f"stop_recording failed: {result}" + + # Verify dataset exists and has frames + ds_path = Path(ds_root) + assert ds_path.exists(), f"Dataset dir not created at {ds_root}" + + # Reopen dataset and verify camera feature has frames + try: + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + ds = LeRobotDataset(repo_id="local/rt-test", root=ds_root) + except (ImportError, 
RuntimeError): + pytest.skip("lerobot dataset API not available (torchcodec/ffmpeg missing)") + + assert len(ds) > 0, f"Dataset has no frames (expected > 0, got {len(ds)})" + + # Check that the camera feature exists (sanitized name) + cam_feature_found = False + for feat_name in ds.features: + if feat_name.startswith("observation.images."): + cam_feature_found = True + break + + assert cam_feature_found, ( + f"No observation.images.* feature found in dataset. Features: {list(ds.features.keys())}" + ) + + # Access a frame and verify image data is present (requires ffmpeg for video decode) + try: + sample = ds[0] + for feat_name in ds.features: + if feat_name.startswith("observation.images."): + assert feat_name in sample, f"Camera feature {feat_name} missing from sample" + img = sample[feat_name] + # Image should be non-empty (tensor or array with shape) + assert hasattr(img, "shape"), f"Camera data has no shape: {type(img)}" + assert img.shape[0] > 0, f"Camera image has zero height: {img.shape}" + break + except RuntimeError: + # torchcodec requires system FFmpeg libraries for video decode + pass + + +class TestMultiRobotDifferentAssetDirs: + """Regression: two robots from different asset dirs both compile and render. + + @yinsong1986 review (2026-04-30): "load two robots whose urdf_paths + are in different directories; assert both render." 
+ """ + + def test_two_robots_different_directories_both_load(self): + """Load two robots from separate temp dirs, verify both have joints.""" + tmpdir_a = tempfile.mkdtemp(prefix="robot_a_") + tmpdir_b = tempfile.mkdtemp(prefix="robot_b_") + + try: + # Write robot A (arm) to dir A + path_a = os.path.join(tmpdir_a, "arm.xml") + with open(path_a, "w") as f: + f.write(ROBOT_XML) + + # Write robot B (gripper) to dir B + path_b = os.path.join(tmpdir_b, "gripper.xml") + with open(path_b, "w") as f: + f.write(ROBOT_B_XML) + + sim = Simulation(tool_name="test_multi_asset", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + + # Add robot A from dir A + result = sim.add_robot("arm1", urdf_path=path_a) + assert result["status"] == "success", f"Robot A failed: {result}" + + # Add robot B from dir B (different asset directory) + result = sim.add_robot("grip1", urdf_path=path_b, position=[0.3, 0, 0]) + assert result["status"] == "success", f"Robot B failed: {result}" + + # Both robots should be registered + assert "arm1" in sim._world.robots + assert "grip1" in sim._world.robots + + # Both should have joints discovered + assert len(sim._world.robots["arm1"].joint_names) == 3 # shoulder_pan, shoulder_lift, elbow + assert len(sim._world.robots["grip1"].joint_names) == 1 # grip_slide + + # Physics step should succeed (proves combined model compiled) + result = sim.step(n_steps=10) + assert result["status"] == "success", f"Step failed: {result}" + + # Verify we can read state from both robots + state_a = sim.get_robot_state("arm1") + assert state_a["status"] == "success", f"State A failed: {state_a}" + state_b = sim.get_robot_state("grip1") + assert state_b["status"] == "success", f"State B failed: {state_b}" + + sim.cleanup() + finally: + shutil.rmtree(tmpdir_a, ignore_errors=True) + shutil.rmtree(tmpdir_b, ignore_errors=True) + + @requires_gl + def test_two_robots_both_render_cameras(self): + """Two robots with cameras from 
different dirs - both cameras render.""" + # Robot A has arm0/wrist_cam (from ROBOT_XML) + # Add a camera to Robot B as well + robot_b_with_cam = """ + + + + + + + + + + + + + +""" + tmpdir_a = tempfile.mkdtemp(prefix="robot_a_cam_") + tmpdir_b = tempfile.mkdtemp(prefix="robot_b_cam_") + + try: + path_a = os.path.join(tmpdir_a, "arm.xml") + with open(path_a, "w") as f: + f.write(ROBOT_XML) + + path_b = os.path.join(tmpdir_b, "gripper_cam.xml") + with open(path_b, "w") as f: + f.write(robot_b_with_cam) + + sim = Simulation(tool_name="test_render_multi", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + + result = sim.add_robot("arm1", urdf_path=path_a) + assert result["status"] == "success" + result = sim.add_robot("grip1", urdf_path=path_b, position=[0.5, 0, 0]) + assert result["status"] == "success" + + # Step to settle physics + sim.step(n_steps=5) + + # Get observation (includes camera renders) + obs = sim._get_sim_observation("arm1") + + # We should have at least one camera rendered (arm0/wrist_cam) + cam_frames = {k: v for k, v in obs.items() if isinstance(v, np.ndarray) and v.ndim == 3} + assert len(cam_frames) > 0, f"No camera frames rendered. Observation keys: {list(obs.keys())}" + + # Verify camera frame is not all-zero (actually rendered something) + for cam_name, frame in cam_frames.items(): + assert frame.shape[2] == 3, f"Camera {cam_name} not RGB: shape={frame.shape}" + # At minimum, the frame should have some non-zero pixels + # (ground plane + colored geoms should provide contrast) + assert frame.sum() > 0, f"Camera {cam_name} rendered all-black frame" + + sim.cleanup() + finally: + shutil.rmtree(tmpdir_a, ignore_errors=True) + shutil.rmtree(tmpdir_b, ignore_errors=True) + + +class TestSceneMutationBlockedDuringPolicy: + """Scene mutations must hard-fail while a policy is running. 
+ + A concurrent PolicyRunner worker calling mj_step on stale model/data + pointers (swapped by XML round-trip in add_object, add_camera, etc.) + is undefined behaviour. The guard ensures agents learn to stop_policy + before modifying the scene. + """ + + @pytest.fixture + def robot_path(self, tmp_path): + """Write test robot XML to a temp file.""" + path = tmp_path / "arm.xml" + path.write_text(ROBOT_XML) + return str(path) + + def test_add_object_blocked_during_policy(self, robot_path): + sim = Simulation(tool_name="test_guard_obj", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + + result = sim.add_robot("arm1", urdf_path=robot_path) + assert result["status"] == "success" + + # Start a policy (fast_mode so it completes quickly after stop) + result = sim.start_policy("arm1", policy_provider="mock", duration=2.0, fast_mode=True) + assert result["status"] == "success" + + # Try adding an object while policy is running - should be blocked + result = sim.add_object("cube", shape="box", position=[0.3, 0, 0.05]) + assert result["status"] == "error" + assert "policy is running" in result["content"][0]["text"].lower() + + # Stop the policy + sim.stop_policy("arm1") + if "arm1" in sim._policy_threads: + sim._policy_threads["arm1"].result(timeout=10.0) + + # Now it should work + result = sim.add_object("cube", shape="box", position=[0.3, 0, 0.05]) + assert result["status"] == "success" + + sim.cleanup() + + def test_add_camera_blocked_during_policy(self, robot_path): + sim = Simulation(tool_name="test_guard_cam", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + + result = sim.add_robot("arm1", urdf_path=robot_path) + assert result["status"] == "success" + + result = sim.start_policy("arm1", policy_provider="mock", duration=2.0, fast_mode=True) + assert result["status"] == "success" + + # Try adding a camera while policy is running - should be blocked + result = 
sim.add_camera("top_cam", position=[0, 0, 2], target=[0, 0, 0]) + assert result["status"] == "error" + assert "policy is running" in result["content"][0]["text"].lower() + + sim.stop_policy("arm1") + if "arm1" in sim._policy_threads: + sim._policy_threads["arm1"].result(timeout=10.0) + + result = sim.add_camera("top_cam", position=[0, 0, 2], target=[0, 0, 0]) + assert result["status"] == "success" + + sim.cleanup() + + def test_load_scene_blocked_during_policy(self, robot_path): + sim = Simulation(tool_name="test_guard_scene", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + + result = sim.add_robot("arm1", urdf_path=robot_path) + assert result["status"] == "success" + + result = sim.start_policy("arm1", policy_provider="mock", duration=2.0, fast_mode=True) + assert result["status"] == "success" + + # load_scene while policy is running - should be blocked + result = sim.load_scene(robot_path) + assert result["status"] == "error" + assert "policy is running" in result["content"][0]["text"].lower() + + sim.stop_policy("arm1") + if "arm1" in sim._policy_threads: + sim._policy_threads["arm1"].result(timeout=10.0) + + sim.cleanup() + + def test_move_object_blocked_during_policy(self, robot_path): + sim = Simulation(tool_name="test_guard_move", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + + result = sim.add_robot("arm1", urdf_path=robot_path) + assert result["status"] == "success" + + # Add an object to move later + result = sim.add_object("cube", shape="box", position=[0.3, 0, 0.05]) + assert result["status"] == "success" + + result = sim.start_policy("arm1", policy_provider="mock", duration=2.0, fast_mode=True) + assert result["status"] == "success" + + # Try moving an object while policy is running - should be blocked + result = sim.move_object("cube", position=[0.5, 0, 0.1]) + assert result["status"] == "error" + assert "policy is running" in 
result["content"][0]["text"].lower() + + sim.stop_policy("arm1") + if "arm1" in sim._policy_threads: + sim._policy_threads["arm1"].result(timeout=10.0) + + # Now it should work + result = sim.move_object("cube", position=[0.5, 0, 0.1]) + assert result["status"] == "success" + + sim.cleanup() + + def test_remove_robot_stops_own_policy_and_succeeds(self, robot_path): + """Per-robot scoping (GH #114): remove_robot(X) gracefully stops X's + own policy before removing it. Previously this errored, forcing the + agent into a two-step stop-then-remove dance even in the common + 'delete the robot I'm running' case. + """ + sim = Simulation(tool_name="test_guard_remove_robot", mesh=False) + result = sim.create_world(gravity=[0, 0, -9.81]) + assert result["status"] == "success" + + result = sim.add_robot("arm1", urdf_path=robot_path) + assert result["status"] == "success" + + result = sim.start_policy("arm1", policy_provider="mock", duration=2.0, fast_mode=True) + assert result["status"] == "success" + + # GH #114: remove_robot on the same arm gracefully stops its policy + # and proceeds. No two-step dance required. + result = sim.remove_robot("arm1") + assert result["status"] == "success", result + assert "arm1" in result["content"][0]["text"] + # Policy future was pruned. + assert "arm1" not in sim._policy_threads + + sim.cleanup() + + def test_remove_robot_blocked_by_OTHER_robot_policy(self, robot_path): + """Global-scope guard (GH #114): remove_robot(A) still errors if + a policy is active on a different robot B, because the XML round-trip + invalidates cached actuator/joint IDs held by B's PolicyRunner. + """ + sim = Simulation(tool_name="test_guard_other_robot", mesh=False) + assert sim.create_world(gravity=[0, 0, -9.81])["status"] == "success" + assert sim.add_robot("armA", urdf_path=robot_path)["status"] == "success" + assert sim.add_robot("armB", urdf_path=robot_path)["status"] == "success" + + # Policy on B... 
+ assert sim.start_policy("armB", policy_provider="mock", duration=5.0, fast_mode=True)["status"] == "success" + + # ...blocks remove_robot on A (scene mutation invalidates IDs). + result = sim.remove_robot("armA") + assert result["status"] == "error" + assert "policy is running" in result["content"][0]["text"].lower() + assert "armB" in result["content"][0]["text"] + + sim.stop_policy("armB") + if "armB" in sim._policy_threads: + sim._policy_threads["armB"].result(timeout=10.0) + + # Now removal works. + assert sim.remove_robot("armA")["status"] == "success" + + sim.cleanup() + + +class TestConcurrentPerRobotPolicies: + """GH #114: two or more policies can run concurrently on different robots. + + Proves the post-fix semantics: + + * ``start_policy`` only blocks on the SAME robot; a second start_policy + on a DIFFERENT robot while the first is running now succeeds. + * ``list_policies_running`` accurately reports all active ones and + prunes completed Futures as a side-effect. + * Two policies mutating their own ``ctrl[]`` slots in parallel never + corrupt MuJoCo state (``self._lock`` still serializes ``mj_step``). + """ + + @pytest.fixture + def robot_path(self, tmp_path): + path = tmp_path / "arm.xml" + path.write_text(ROBOT_XML) + return str(path) + + def test_start_policy_allowed_on_second_robot_while_first_runs(self, robot_path): + sim = Simulation(tool_name="test_concurrent_start", mesh=False) + assert sim.create_world()["status"] == "success" + assert sim.add_robot("armA", urdf_path=robot_path)["status"] == "success" + assert sim.add_robot("armB", urdf_path=robot_path)["status"] == "success" + + # First policy starts. + r1 = sim.start_policy("armA", policy_provider="mock", duration=3.0, fast_mode=True) + assert r1["status"] == "success", r1 + + # Second policy on a DIFFERENT robot also starts (per-robot gate). + r2 = sim.start_policy("armB", policy_provider="mock", duration=3.0, fast_mode=True) + assert r2["status"] == "success", r2 + + # Both active. 
+ active = sim._active_policy_robots() + assert set(active) == {"armA", "armB"}, active + + sim.stop_policy("armA") + sim.stop_policy("armB") + # Wait for graceful stop. + for name in ("armA", "armB"): + fut = sim._policy_threads.get(name) + if fut is not None: + try: + fut.result(timeout=10.0) + except Exception: + pass + sim.cleanup() + + def test_start_policy_still_rejected_on_SAME_robot(self, robot_path): + """Per-robot gate still fires when we start twice on the same robot.""" + sim = Simulation(tool_name="test_concurrent_same", mesh=False) + assert sim.create_world()["status"] == "success" + assert sim.add_robot("arm1", urdf_path=robot_path)["status"] == "success" + + r1 = sim.start_policy("arm1", policy_provider="mock", duration=3.0, fast_mode=True) + assert r1["status"] == "success" + + r2 = sim.start_policy("arm1", policy_provider="mock", duration=3.0, fast_mode=True) + assert r2["status"] == "error" + assert "arm1" in r2["content"][0]["text"] + + sim.stop_policy("arm1") + fut = sim._policy_threads.get("arm1") + if fut is not None: + try: + fut.result(timeout=10.0) + except Exception: + pass + sim.cleanup() + + def test_list_policies_running_reports_active(self, robot_path): + sim = Simulation(tool_name="test_list_policies", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + sim.add_robot("armB", urdf_path=robot_path) + + # None active. + r = sim.list_policies_running() + assert r["status"] == "success" + assert "No policies" in r["content"][0]["text"] + + # One active. + sim.start_policy("armA", policy_provider="mock", duration=3.0, fast_mode=True) + r = sim.list_policies_running() + assert r["status"] == "success" + assert "armA" in r["content"][0]["text"] + assert "armB" not in r["content"][0]["text"] + + # Two active. 
+ sim.start_policy("armB", policy_provider="mock", duration=3.0, fast_mode=True) + r = sim.list_policies_running() + assert "armA" in r["content"][0]["text"] + assert "armB" in r["content"][0]["text"] + + # Clean shutdown. + sim.stop_policy("armA") + sim.stop_policy("armB") + for name in ("armA", "armB"): + fut = sim._policy_threads.get(name) + if fut is not None: + try: + fut.result(timeout=10.0) + except Exception: + pass + + # After both stop, list is empty again (stale prune). + r = sim.list_policies_running() + assert "No policies" in r["content"][0]["text"] + assert sim._policy_threads == {} + + sim.cleanup() + + def test_completed_futures_are_pruned(self, robot_path): + """GH #120 (companion fix): completed Futures must not linger in + _policy_threads forever. + """ + sim = Simulation(tool_name="test_prune", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + + # Very short policy - let it complete naturally. + sim.start_policy("armA", policy_provider="mock", duration=0.1, fast_mode=True) + fut = sim._policy_threads.get("armA") + assert fut is not None + try: + fut.result(timeout=10.0) + except Exception: + pass + + # Future is done - one introspection call prunes it. + active = sim._active_policy_robots() + assert active == [], active + assert "armA" not in sim._policy_threads + + sim.cleanup() + + def test_scene_mutation_lists_which_robots_are_running(self, robot_path): + """Error message names the active-policy robots so the LLM can + stop_policy on each without guessing. 
+ """ + sim = Simulation(tool_name="test_err_msg", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + sim.add_robot("armB", urdf_path=robot_path) + + sim.start_policy("armA", policy_provider="mock", duration=3.0, fast_mode=True) + sim.start_policy("armB", policy_provider="mock", duration=3.0, fast_mode=True) + + r = sim.set_gravity([0, 0, -5.0]) + assert r["status"] == "error" + text = r["content"][0]["text"] + assert "armA" in text + assert "armB" in text + + sim.stop_policy("armA") + sim.stop_policy("armB") + for name in ("armA", "armB"): + fut = sim._policy_threads.get(name) + if fut is not None: + try: + fut.result(timeout=10.0) + except Exception: + pass + sim.cleanup() + + def test_two_policies_no_segfault_under_stress(self, robot_path): + """Smoke test: two concurrent policies actually *run* (not just + both "started") and produce step_count > 0 on both robots, with + self._lock serializing the shared mj_step safely. + + Uses a short duration + fast_mode so the test finishes under + a second. + """ + sim = Simulation(tool_name="test_stress_concurrent", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + sim.add_robot("armB", urdf_path=robot_path) + + sim.start_policy("armA", policy_provider="mock", duration=0.5, fast_mode=True) + sim.start_policy("armB", policy_provider="mock", duration=0.5, fast_mode=True) + + # Let both run to completion. + for name in ("armA", "armB"): + fut = sim._policy_threads.get(name) + if fut is not None: + try: + fut.result(timeout=15.0) + except Exception: + pass + + # Both robots advanced their step counter - proves both ran. 
+ assert sim._world is not None + assert sim._world.robots["armA"].policy_steps > 0, "armA never stepped - concurrent scheduling broke it" + assert sim._world.robots["armB"].policy_steps > 0, "armB never stepped - concurrent scheduling broke it" + + sim.cleanup() + + +class TestCleanupGracefulShutdown: + """GH #116: cleanup() must wait for live policy workers before nulling + the world, otherwise an in-flight mj_step segfaults on freed arrays. + """ + + @pytest.fixture + def robot_path(self, tmp_path): + path = tmp_path / "arm.xml" + path.write_text(ROBOT_XML) + return str(path) + + def test_cleanup_awaits_running_policy(self, robot_path): + """Start a long-running policy, call cleanup, verify the worker + completed (Future.done()) before cleanup returned and we do NOT + segfault on world nulling.""" + sim = Simulation(tool_name="test_cleanup_await", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + + sim.start_policy("armA", policy_provider="mock", duration=5.0, fast_mode=True) + fut = sim._policy_threads.get("armA") + assert fut is not None and not fut.done(), "policy should be live" + + # Cleanup with tight timeout - the cooperative-stop flag is read + # every step so 1s is plenty for MockPolicy to exit. + sim.cleanup(policy_stop_timeout=2.0) + + # Post-cleanup invariants. + assert fut.done(), "Future must have terminated before cleanup returned" + assert sim._world is None, "world must be nulled after cleanup" + assert sim._policy_threads == {}, "policy_threads must be drained" + + def test_cleanup_tolerates_wedged_policy(self, robot_path): + """A policy that refuses to stop within the timeout must NOT hang + the whole process. 
Cleanup logs a warning and proceeds.""" + sim = Simulation(tool_name="test_cleanup_wedged", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + + sim.start_policy("armA", policy_provider="mock", duration=5.0, fast_mode=True) + + # Aggressively short timeout forces the "wedged" path even if the + # mock is fast - the test is that cleanup RETURNS in bounded time, + # not that the future is done. + import time as _time + + t0 = _time.monotonic() + sim.cleanup(policy_stop_timeout=0.001) + elapsed = _time.monotonic() - t0 + + # Even with timeout=1ms, total cleanup must complete quickly. + # We allow some slack for teardown of renderers/viewer. + assert elapsed < 10.0, f"cleanup blocked too long: {elapsed:.2f}s" + assert sim._world is None + + def test_cleanup_is_idempotent_with_no_policies(self, robot_path): + """Calling cleanup with no live policies must be a straight no-op + for the policy-drain path (no Futures to wait on).""" + sim = Simulation(tool_name="test_cleanup_noop", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + # No start_policy call. + + sim.cleanup(policy_stop_timeout=0.1) + + assert sim._world is None + assert sim._policy_threads == {} + + def test_cleanup_drains_multiple_concurrent_policies(self, robot_path): + """With concurrent per-robot policies (GH #114), cleanup must await + BOTH before nulling the world.""" + sim = Simulation(tool_name="test_cleanup_multi", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + sim.add_robot("armB", urdf_path=robot_path) + + sim.start_policy("armA", policy_provider="mock", duration=5.0, fast_mode=True) + sim.start_policy("armB", policy_provider="mock", duration=5.0, fast_mode=True) + + futs = {name: sim._policy_threads.get(name) for name in ("armA", "armB")} + assert all(f is not None and not f.done() for f in futs.values()) + + sim.cleanup(policy_stop_timeout=3.0) + + # Both worker futures settled before cleanup returned. 
+ for name, fut in futs.items(): + assert fut is not None and fut.done(), f"'{name}' future was not awaited" + assert sim._world is None + + +class TestMutationGuardStress: + """GH #119: hammer the mutation guard to prove no race between + the ``_require_no_running_policy`` check and the PolicyRunner's + ``mj_step`` call. Historically we relied on the check being 'atomic + enough in practice' - no test proved it. + + The critical contract we're validating: + + 1. Every scene-mutation call attempted while a policy is live must + either (a) return status=error with our uniform message, or + (b) return status=success if the policy has already settled. + NOTHING may corrupt MuJoCo state or segfault. + + 2. The mutation guard must be fast enough that 1000 concurrent + requests from the main thread do not starve the policy worker. + """ + + @pytest.fixture + def robot_path(self, tmp_path): + path = tmp_path / "arm.xml" + path.write_text(ROBOT_XML) + return str(path) + + def test_1000_set_gravity_calls_during_policy_never_segfault(self, robot_path): + """Start a policy, then bang set_gravity 1000 times from the main + thread. Every call must return a well-formed dict - no crash, no + half-applied mutation. Once the policy ends, the last set_gravity + succeeds.""" + sim = Simulation(tool_name="test_stress_set_gravity", mesh=False) + sim.create_world() + sim.add_robot("arm", urdf_path=robot_path) + + sim.start_policy("arm", policy_provider="mock", duration=1.0, fast_mode=True) + + # Hammer from the main thread while the worker runs. + blocked = 0 + succeeded = 0 + for _ in range(1000): + r = sim.set_gravity([0.0, 0.0, -9.81]) + assert isinstance(r, dict), r + assert r["status"] in ("success", "error"), r + if r["status"] == "error": + assert "policy is running" in r["content"][0]["text"].lower() + blocked += 1 + else: + succeeded += 1 + + # At least one call must have been blocked (policy was live). 
+ assert blocked > 0, "stress loop never saw the policy as live - timing broken" + + # After policy finishes, set_gravity works. + fut = sim._policy_threads.get("arm") + if fut is not None: + try: + fut.result(timeout=10.0) + except Exception: + pass + + result = sim.set_gravity([0.0, 0.0, -5.0]) + assert result["status"] == "success" + + sim.cleanup(policy_stop_timeout=2.0) + + def test_rapid_start_stop_start_stop_policy(self, robot_path): + """Stress the Future lifecycle. Rapid start/stop cycles must leave + _policy_threads in a consistent state every iteration.""" + sim = Simulation(tool_name="test_rapid_cycle", mesh=False) + sim.create_world() + sim.add_robot("arm", urdf_path=robot_path) + + for i in range(10): + r_start = sim.start_policy("arm", policy_provider="mock", duration=2.0, fast_mode=True) + assert r_start["status"] == "success", (i, r_start) + + r_stop = sim.stop_policy("arm") + assert r_stop["status"] == "success", (i, r_stop) + + # Await worker so the next start_policy doesn't race. + fut = sim._policy_threads.get("arm") + if fut is not None: + try: + fut.result(timeout=5.0) + except Exception: + pass + + # Prune runs as a side effect of _active_policy_robots. + active = sim._active_policy_robots() + assert active == [], (i, active) + + sim.cleanup(policy_stop_timeout=2.0) + + def test_mutation_accepted_immediately_after_policy_completes(self, robot_path): + """Once the policy Future is done(), the VERY NEXT scene mutation + must succeed - no lingering guard state from the just-completed run.""" + sim = Simulation(tool_name="test_no_lingering_guard", mesh=False) + sim.create_world() + sim.add_robot("arm", urdf_path=robot_path) + + # Very short policy. + sim.start_policy("arm", policy_provider="mock", duration=0.05, fast_mode=True) + fut = sim._policy_threads.get("arm") + assert fut is not None + try: + fut.result(timeout=5.0) + except Exception: + pass + assert fut.done() + + # First mutation after completion must succeed. 
+ r = sim.set_gravity([0.0, 0.0, -9.81]) + assert r["status"] == "success", r + + sim.cleanup(policy_stop_timeout=1.0) + + def test_concurrent_policies_stress_no_deadlock(self, robot_path): + """Two concurrent policies (GH #114) + main-thread mutation spam + must not deadlock on self._lock.""" + sim = Simulation(tool_name="test_concurrent_stress", mesh=False) + sim.create_world() + sim.add_robot("armA", urdf_path=robot_path) + sim.add_robot("armB", urdf_path=robot_path) + + sim.start_policy("armA", policy_provider="mock", duration=1.0, fast_mode=True) + sim.start_policy("armB", policy_provider="mock", duration=1.0, fast_mode=True) + + blocked = 0 + errors = 0 + for _ in range(500): + r = sim.set_gravity([0.0, 0.0, -9.81]) + assert r["status"] in ("success", "error"), r + if r["status"] == "error": + # When blocked, the message must name AT LEAST one robot. + text = r["content"][0]["text"] + if "armA" in text or "armB" in text: + blocked += 1 + else: + errors += 1 + + assert errors == 0, f"unexpected error shape: {errors}" + assert blocked > 0, "never caught policies as live" + + # Wait for both to settle. + for name in ("armA", "armB"): + fut = sim._policy_threads.get(name) + if fut is not None: + try: + fut.result(timeout=10.0) + except Exception: + pass + + sim.cleanup(policy_stop_timeout=2.0) diff --git a/tests/simulation/mujoco/test_e2e.py b/tests/simulation/mujoco/test_e2e.py new file mode 100644 index 0000000..582c050 --- /dev/null +++ b/tests/simulation/mujoco/test_e2e.py @@ -0,0 +1,314 @@ +"""End-to-end MuJoCo simulation test with Policy ABC. + +Tests the full observe → policy → act → step → render pipeline +without requiring strands SDK or lerobot - just mujoco + numpy. 
+ +Run: python -m pytest tests/simulation/mujoco/test_e2e.py -v +""" + +import asyncio +import os +import shutil +import tempfile + +import numpy as np +import pytest + +# Skip entire module if mujoco not installed +mj = pytest.importorskip("mujoco") + + +def _has_opengl() -> bool: + """Check if OpenGL rendering is available.""" + try: + model = mj.MjModel.from_xml_string("") + renderer = mj.Renderer(model, height=1, width=1) + del renderer + return True + except Exception: + return False + + +requires_gl = pytest.mark.skipif( + not _has_opengl(), + reason="No OpenGL context available (headless environment without EGL/OSMesa)", +) + + +from strands_robots.policies import MockPolicy # noqa: E402 +from strands_robots.simulation.base import SimEngine # noqa: E402 +from strands_robots.simulation.models import SimObject, SimRobot, SimStatus, SimWorld # noqa: E402 + +# Fixtures + +ROBOT_XML = """ + + + +""" + + +@pytest.fixture +def sim_env(): + """Create a MuJoCo model+data from test XML.""" + tmpdir = tempfile.mkdtemp() + xml_path = os.path.join(tmpdir, "test_arm.xml") + with open(xml_path, "w") as f: + f.write(ROBOT_XML) + + model = mj.MjModel.from_xml_path(xml_path) + data = mj.MjData(model) + + yield model, data + + shutil.rmtree(tmpdir, ignore_errors=True) + + +JOINT_NAMES = ["shoulder_pan", "shoulder_lift", "elbow"] + + +def read_joints(model, data): + obs = {} + for jname in JOINT_NAMES: + jid = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, jname) + obs[jname] = float(data.qpos[model.jnt_qposadr[jid]]) + return obs + + +def apply_action(model, data, action_dict): + for key, val in action_dict.items(): + act_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_ACTUATOR, f"{key}_act") + if act_id >= 0: + data.ctrl[act_id] = val + + +# Tests + + +class TestSimulationBase: + def test_abc_has_required_methods(self): + required = [ + "create_world", + "destroy", + "reset", + "step", + "get_state", + "add_robot", + "remove_robot", + "add_object", + "remove_object", + "get_observation", +
"send_action", + "render", + ] + for method in required: + assert hasattr(SimEngine, method) + + def test_shared_dataclasses(self): + w = SimWorld() + assert w.timestep == 0.002 + assert w.gravity == [0.0, 0.0, -9.81] + assert w.status == SimStatus.IDLE + + r = SimRobot(name="test", urdf_path="/tmp/test.urdf") + assert r.joint_names == [] + + o = SimObject(name="cube", shape="box") + assert o.mass == 0.1 + + +class TestMuJoCoPhysics: + def test_step_advances_time(self, sim_env): + model, data = sim_env + assert data.time == 0.0 + for _ in range(100): + mj.mj_step(model, data) + assert data.time == pytest.approx(0.2, abs=1e-6) + + def test_position_actuators_move_joints(self, sim_env): + model, data = sim_env + data.ctrl[0] = 1.0 # shoulder_pan target + for _ in range(1000): + mj.mj_step(model, data) + obs = read_joints(model, data) + assert abs(obs["shoulder_pan"] - 1.0) < 0.15 + + def test_contacts_detected(self, sim_env): + model, data = sim_env + for _ in range(100): + mj.mj_step(model, data) + assert data.ncon > 0 # cube on ground + + def test_reset_zeros_time(self, sim_env): + model, data = sim_env + for _ in range(100): + mj.mj_step(model, data) + mj.mj_resetData(model, data) + assert data.time == 0.0 + + +@requires_gl +class TestMuJoCoRendering: + def test_render_rgb(self, sim_env): + model, data = sim_env + mj.mj_forward(model, data) + renderer = mj.Renderer(model, height=240, width=320) + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, "front") + renderer.update_scene(data, camera=cam_id) + img = renderer.render() + assert img.shape == (240, 320, 3) + assert img.dtype == np.uint8 + assert img.max() > 0 + del renderer + + def test_render_depth(self, sim_env): + model, data = sim_env + mj.mj_forward(model, data) + renderer = mj.Renderer(model, height=120, width=160) + renderer.update_scene(data) + renderer.enable_depth_rendering() + depth = renderer.render() + renderer.disable_depth_rendering() + assert depth.shape == (120, 160) + assert depth.max() > 
0 + del renderer + + +class TestMockPolicyLoop: + def test_mock_policy_generates_actions(self): + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + obs = {j: 0.0 for j in JOINT_NAMES} + actions = asyncio.run(policy.get_actions(obs, "test")) + assert len(actions) == 8 + assert all(j in actions[0] for j in JOINT_NAMES) + + def test_full_observe_act_loop(self, sim_env): + model, data = sim_env + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + + for step in range(20): + obs = read_joints(model, data) + actions = asyncio.run(policy.get_actions(obs, "pick up cube")) + apply_action(model, data, actions[0]) + mj.mj_step(model, data) + + assert data.time > 0 + final_obs = read_joints(model, data) + # Joints should have moved from 0 + assert any(abs(v) > 0.001 for v in final_obs.values()) + + @requires_gl + def test_loop_with_rendering(self, sim_env): + """Full loop: observe → policy → act → step → render (10 iterations).""" + model, data = sim_env + policy = MockPolicy() + policy.set_robot_state_keys(JOINT_NAMES) + renderer = mj.Renderer(model, height=120, width=160) + + frames = [] + for _ in range(10): + obs = read_joints(model, data) + actions = asyncio.run(policy.get_actions(obs, "wave")) + apply_action(model, data, actions[0]) + mj.mj_step(model, data) + + renderer.update_scene(data) + frames.append(renderer.render().copy()) + + assert len(frames) == 10 + assert all(f.shape == (120, 160, 3) for f in frames) + # Frames should differ (robot is moving) + assert not np.array_equal(frames[0], frames[-1]) + del renderer + + +class TestDomainRandomization: + def test_color_randomization(self, sim_env): + model, data = sim_env + orig = model.geom_rgba.copy() + rng = np.random.default_rng(42) + for i in range(model.ngeom): + model.geom_rgba[i, :3] = rng.uniform(0.1, 1.0, size=3) + assert not np.array_equal(orig, model.geom_rgba) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) + + +class TestToolSpecActionCoverage: + """Verify 
every action enum in tool_spec.json maps to a real method on Simulation.""" + + def test_all_actions_have_methods(self): + """Every action in tool_spec.json must resolve to a method on Simulation.""" + import json + from pathlib import Path + + from strands_robots.simulation.mujoco.simulation import Simulation + + spec_path = Path(__file__).resolve().parents[3] / "strands_robots" / "simulation" / "mujoco" / "tool_spec.json" + with open(spec_path) as f: + spec = json.load(f) + + actions = spec["properties"]["action"]["enum"] + assert len(actions) > 0, "tool_spec.json should have at least one action" + + # Aliases used by _dispatch_action + aliases = { + "list_robots": "list_robots_info", + } + + missing = [] + for action in actions: + method_name = aliases.get(action, action) + if not hasattr(Simulation, method_name): + missing.append(f"{action} (looked for method '{method_name}')") + + assert not missing, "tool_spec.json actions with no matching Simulation method:\n" + "\n".join( + f" - {m}" for m in missing + ) + + def test_action_enum_is_not_empty(self): + """Sanity: tool_spec.json action enum is populated.""" + import json + from pathlib import Path + + spec_path = Path(__file__).resolve().parents[3] / "strands_robots" / "simulation" / "mujoco" / "tool_spec.json" + with open(spec_path) as f: + spec = json.load(f) + + actions = spec["properties"]["action"]["enum"] + assert len(actions) >= 30, f"Expected ≥30 actions, got {len(actions)}" diff --git a/tests/simulation/mujoco/test_error_paths.py b/tests/simulation/mujoco/test_error_paths.py new file mode 100644 index 0000000..39e273b --- /dev/null +++ b/tests/simulation/mujoco/test_error_paths.py @@ -0,0 +1,348 @@ +"""Error-path coverage for MuJoCo ``Simulation`` public methods. + +Every public method should return ``{"status": "error", ...}`` (never raise) +for: +* invalid identifiers (unknown body/geom/joint/sensor names) +* out-of-bounds numeric ids +* missing-arg edge cases (None positions, None velocities, etc.) 
+* ghost checkpoints / ghost cameras / idle policy stop +* pathological shape params (negative timestep, short gravity vector) + +This locks the AgentTool contract: the LLM-facing surface must never bubble +a raw exception. +""" + +from __future__ import annotations + +import os +import shutil +import tempfile + +import pytest + +mj = pytest.importorskip("mujoco") + +os.environ.setdefault("MUJOCO_GL", "glfw") + +# Inline robot XML - avoids network dependency on robot model repos +_ROBOT_XML = """ + + + +""" + + +@pytest.fixture +def ready_sim(): + from strands_robots.simulation import Simulation + + tmpdir = tempfile.mkdtemp() + path = os.path.join(tmpdir, "test_arm.xml") + with open(path, "w") as f: + f.write(_ROBOT_XML) + + s = Simulation() + s.create_world(timestep=0.002) + result = s.add_robot("arm", urdf_path=path, position=[0.0, 0.0, 0.0]) + assert result["status"] == "success", f"add_robot failed: {result}" + s.step(n_steps=5) + yield s + s.destroy() + shutil.rmtree(tmpdir, ignore_errors=True) + + +# ─ Physics: unknown-name + out-of-bounds──────────────────────────── + + +def test_set_geom_properties_out_of_bounds_id_errors_gracefully(ready_sim): + r = ready_sim.set_geom_properties(geom_id=999999, color=[1, 0, 0, 1]) + assert r["status"] == "error" + assert "not found" in r["content"][0]["text"] + + +def test_set_geom_properties_unknown_name_errors_gracefully(ready_sim): + r = ready_sim.set_geom_properties(geom_name="__does_not_exist__", color=[1, 0, 0, 1]) + assert r["status"] == "error" + assert "not found" in r["content"][0]["text"] + + +def test_set_body_properties_unknown_name_errors_gracefully(ready_sim): + r = ready_sim.set_body_properties(body_name="__ghost_body__", mass=1.0) + assert r["status"] == "error" + + +def test_get_jacobian_unknown_body_errors(ready_sim): + r = ready_sim.get_jacobian(body_name="__no_such_body__") + assert r["status"] == "error" + + +def test_get_jacobian_unknown_site_errors(ready_sim): + r = 
ready_sim.get_jacobian(site_name="__no_such_site__") + assert r["status"] == "error" + + +def test_get_jacobian_unknown_geom_errors(ready_sim): + r = ready_sim.get_jacobian(geom_name="__no_such_geom__") + assert r["status"] == "error" + + +def test_set_joint_positions_none_dict_errors(ready_sim): + # Post-T11: message updated to explain list OR dict is accepted. + r = ready_sim.set_joint_positions(positions=None) + assert r["status"] == "error" + assert "'positions' is required" in r["content"][0]["text"] + + +def test_set_joint_velocities_none_dict_errors(ready_sim): + # Post-T11: message updated to explain list OR dict is accepted. + r = ready_sim.set_joint_velocities(velocities=None) + assert r["status"] == "error" + assert "'velocities' is required" in r["content"][0]["text"] + + +def test_set_joint_positions_unknown_joint_is_skipped_not_raised(ready_sim): + """Unknown joint names are logged and skipped - not fatal.""" + joints = ready_sim.robot_joint_names("arm") + assert len(joints) > 0, "Fixture robot must have joints" + r = ready_sim.set_joint_positions(positions={joints[0]: 0.1, "__nope__": 0.2}) + assert r["status"] == "success" # the valid joint still applied + + +def test_apply_force_torque_only(ready_sim): + """apply_force with torque-only (force=None) should still succeed.""" + r = ready_sim.apply_force(body_name="arm/base", torque=[0.0, 0.0, 0.1]) + assert r["status"] == "success" + + +def test_apply_force_unknown_body_errors(ready_sim): + r = ready_sim.apply_force(body_name="__ghost__", force=[1, 0, 0]) + assert r["status"] == "error" + + +def test_get_sensor_data_no_sensors_returns_info(ready_sim): + """Test arm has no sensors → returns success with an informational text.""" + r = ready_sim.get_sensor_data() + assert r["status"] == "success" + assert "No sensors" in r["content"][0]["text"] + + +def test_get_sensor_data_unknown_name_errors(ready_sim): + """T45: requesting a specific sensor name on a model with no sensors must + report a clear 'not 
found' error (distinguishable from 'no sensors at all' + when no name was given). + """ + r = ready_sim.get_sensor_data(sensor_name="__ghost_sensor__") + assert r["status"] == "error" + text = r["content"][0]["text"] + assert "__ghost_sensor__" in text + assert "not found" in text + + +def test_get_body_state_unknown_body_errors(ready_sim): + r = ready_sim.get_body_state(body_name="__ghost__") + assert r["status"] == "error" + + +# ─ State mgmt: ghost checkpoints─────────────────────────────────── + + +def test_load_state_unknown_checkpoint_errors(ready_sim): + r = ready_sim.load_state(name="__never_saved__") + assert r["status"] == "error" + + +def test_save_state_then_load_state_round_trips(ready_sim): + r = ready_sim.save_state(name="probe") + assert r["status"] == "success" + r = ready_sim.load_state(name="probe") + assert r["status"] == "success" + + +# ─ Scene mutations: ghosts────────────────────────────────────────── + + +def test_remove_robot_ghost_errors(ready_sim): + r = ready_sim.remove_robot("__never_added__") + assert r["status"] == "error" + + +def test_remove_object_ghost_errors(ready_sim): + r = ready_sim.remove_object("__never_added__") + assert r["status"] == "error" + + +def test_remove_camera_ghost_errors(ready_sim): + r = ready_sim.remove_camera("__never_added__") + assert r["status"] == "error" + + +def test_move_object_ghost_errors(ready_sim): + r = ready_sim.move_object(name="__ghost__", position=[0, 0, 0.1]) + assert r["status"] == "error" + + +# ─ Policy lifecycle───────────────────────────────────────────────── + + +def test_stop_policy_on_idle_robot_errors(ready_sim): + """stop_policy on a robot that isn't running a policy is a no-op error.""" + r = ready_sim.stop_policy("arm") + # Some implementations may return "success" with a no-op message; the + # contract is: no exception, a dict back, and the flag ends up cleared. 
+ assert isinstance(r, dict) + assert r.get("status") in ("success", "error") + + +def test_stop_policy_ghost_robot_errors(ready_sim): + r = ready_sim.stop_policy("__ghost_robot__") + assert r["status"] == "error" + + +# ─ World controls──────────────────────────────────────────────── + + +def test_step_zero_is_noop(ready_sim): + t_pre = ready_sim._world.sim_time + r = ready_sim.step(n_steps=0) + assert r["status"] == "success" + assert ready_sim._world.sim_time == t_pre + + +def test_reset_after_perturbation_restores_time(ready_sim): + ready_sim.step(n_steps=20) + assert ready_sim._world.sim_time > 0 + r = ready_sim.reset() + assert r["status"] == "success" + + +def test_set_gravity_scalar(ready_sim): + """A scalar is interpreted as downward gravity.""" + r = ready_sim.set_gravity(-9.8) + assert r["status"] == "success" + + +def test_set_gravity_3_vector(ready_sim): + r = ready_sim.set_gravity([0.0, 0.0, -3.7]) + assert r["status"] == "success" + + +def test_set_timestep_positive(ready_sim): + r = ready_sim.set_timestep(0.004) + assert r["status"] == "success" + + +# ─ Rendering: unknown camera, render-unavailable paths────────── + + +def test_render_all_with_only_missing_cameras_errors(ready_sim): + """Explicit camera list that matches nothing returns an error.""" + r = ready_sim.render_all(cameras=["ghost_cam_a", "ghost_cam_b"]) + assert r["status"] == "error" + + +def test_render_unknown_camera_falls_back(ready_sim): + """Unknown camera_name → fallback renders with the default view.""" + r = ready_sim.render(camera_name="__not_a_camera__", width=32, height=24) + # MuJoCo falls back to a free camera when cam_id < 0 - should succeed + # unless GL context is unavailable, in which case error is acceptable + assert r["status"] in ("success", "error") + + +# ─ Tool-spec dispatch: unknown action + error routing─────────── + + +def test_dispatch_private_action_is_rejected(ready_sim): + """Dispatcher must refuse private leading-underscore names.""" + r = 
ready_sim._dispatch_action("_stop_policy", {"action": "_stop_policy"}) + assert r["status"] == "error" + assert "Unknown action" in r["content"][0]["text"] + + +def test_dispatch_field_remap_checkpoint_name_to_name(ready_sim): + """The dispatcher remaps ``checkpoint_name`` → ``name`` for save_state.""" + r = ready_sim._dispatch_action("save_state", {"action": "save_state", "checkpoint_name": "remap_probe"}) + assert r["status"] == "success" + r = ready_sim._dispatch_action("load_state", {"action": "load_state", "checkpoint_name": "remap_probe"}) + assert r["status"] == "success" + + +# ── Properties ───────────────────────────────────────────────────── + + +def test_mj_model_and_mj_data_return_none_before_world(): + """Direct MuJoCo handles are ``None`` until ``create_world`` runs.""" + from strands_robots.simulation import Simulation + + s = Simulation() + assert s.mj_model is None + assert s.mj_data is None + s.destroy() + + +def test_mj_model_and_mj_data_after_world(ready_sim): + """After ``create_world + add_robot`` the handles are populated.""" + import mujoco as mj + + assert isinstance(ready_sim.mj_model, mj.MjModel) + assert isinstance(ready_sim.mj_data, mj.MjData) + + +# ── Observation edge cases (ABC path in Simulation.get_observation) ── + + +def test_get_observation_no_world_returns_empty_dict(): + from strands_robots.simulation import Simulation + + s = Simulation() + assert s.get_observation() == {} + s.destroy() + + +def test_get_observation_no_robots_returns_empty_dict(): + """``get_observation()`` with no robots added yet → ``{}`` (not a raise).""" + from strands_robots.simulation import Simulation + + s = Simulation() + s.create_world() + assert s.get_observation() == {} + s.destroy() + + +def test_get_observation_unknown_robot_returns_empty_dict(ready_sim): + assert ready_sim.get_observation(robot_name="__ghost__") == {} + + +def test_send_action_no_world_is_noop(): + from strands_robots.simulation import Simulation + + s = Simulation() + # 
Should return None and not raise + assert s.send_action({"j": 0.1}) is None + s.destroy() + + +def test_send_action_unknown_robot_is_noop(ready_sim): + assert ready_sim.send_action({"j": 0.1}, robot_name="__ghost__") is None diff --git a/tests/simulation/mujoco/test_input_validation.py b/tests/simulation/mujoco/test_input_validation.py new file mode 100644 index 0000000..95929db --- /dev/null +++ b/tests/simulation/mujoco/test_input_validation.py @@ -0,0 +1,436 @@ +"""Input validation regression tests for PR #85 fixes (T7, T9, T10). + +These guard against silent data-integrity bugs and process-killing MuJoCo +aborts that were caught by autonomous local testing on PR #85. +""" + +import pytest + +pytest.importorskip("mujoco") + +from strands_robots.simulation.mujoco.simulation import Simulation + + +@pytest.fixture +def sim_with_world(): + """A minimal simulation with an empty world for validation tests.""" + sim = Simulation() + sim.create_world() + yield sim + sim.destroy() + + +@pytest.fixture +def sim_with_robot(): + """A simulation with a single robot for physics-validation tests.""" + sim = Simulation() + sim.create_world() + # Use a built-in registry robot - no network I/O + res = sim.add_robot(name="panda", data_config="panda") + if res["status"] != "success": + pytest.skip(f"panda not available: {res['content'][0]['text']}") + sim.reset() + yield sim + sim.destroy() + + +# --- T9: step validation -------------------------------------------------- + + +class TestStepValidation: + def test_step_negative_errors(self, sim_with_world): + """step(n_steps=-5) must error and NOT decrement step_count.""" + initial = sim_with_world._world.step_count + res = sim_with_world.step(n_steps=-5) + assert res["status"] == "error" + assert "n_steps must be >= 0" in res["content"][0]["text"] + assert sim_with_world._world.step_count == initial, "step_count must not change on rejected call" + + def test_step_zero_is_noop(self, sim_with_world): + """step(n_steps=0) is a 
successful no-op.""" + initial = sim_with_world._world.step_count + res = sim_with_world.step(n_steps=0) + assert res["status"] == "success" + assert "no-op" in res["content"][0]["text"].lower() + assert sim_with_world._world.step_count == initial + + def test_step_positive_still_works(self, sim_with_world): + """Baseline: non-negative n_steps continues to work.""" + res = sim_with_world.step(n_steps=3) + assert res["status"] == "success" + assert sim_with_world._world.step_count == 3 + + +# --- T7: raycast zero-direction guard ------------------------------------- + + +class TestRaycastValidation: + def test_zero_direction_errors_not_crash(self, sim_with_robot): + """raycast with zero direction used to abort the interpreter. Now errors cleanly.""" + res = sim_with_robot.raycast(origin=[0, 0, 1], direction=[0, 0, 0]) + assert res["status"] == "error" + assert "zero-length" in res["content"][0]["text"].lower() + + def test_wrong_length_direction_errors(self, sim_with_robot): + res = sim_with_robot.raycast(origin=[0, 0, 1], direction=[0, 0]) + assert res["status"] == "error" + assert "3 elements" in res["content"][0]["text"] + + def test_wrong_length_origin_errors(self, sim_with_robot): + res = sim_with_robot.raycast(origin=[0, 0], direction=[0, 0, 1]) + assert res["status"] == "error" + assert "3 elements" in res["content"][0]["text"] + + def test_valid_raycast_still_works(self, sim_with_robot): + res = sim_with_robot.raycast(origin=[0, 0, 5], direction=[0, 0, -1]) + assert res["status"] == "success" + + def test_multi_raycast_zero_direction_isolates_error(self, sim_with_robot): + """A zero-length direction in one ray must not abort the whole batch.""" + res = sim_with_robot.multi_raycast( + origin=[0, 0, 5], + directions=[[0, 0, -1], [0, 0, 0], [1, 0, -1]], + ) + assert res["status"] == "success" + # The JSON payload should show error on ray[1] only + rays = res["content"][1]["json"]["rays"] + assert len(rays) == 3 + assert rays[1].get("error") is not None + assert 
"zero-length" in rays[1]["error"] + + +# --- T10: apply_force must reject missing-both -------------------------- + + +class TestApplyForceValidation: + def test_missing_both_force_and_torque_errors(self, sim_with_robot): + """apply_force(body='link1') with no force/torque must error, not silent success.""" + res = sim_with_robot.apply_force(body_name="link1") + assert res["status"] == "error" + assert "at least one" in res["content"][0]["text"].lower() + + def test_explicit_zero_force_still_clears_latched(self, sim_with_robot): + """Regression: apply_force(body, force=[0,0,0]) is the documented way to clear.""" + # First latch a force + r1 = sim_with_robot.apply_force(body_name="link1", force=[10, 0, 0]) + assert r1["status"] == "success" + # Then clear with explicit zero - this MUST remain valid + r2 = sim_with_robot.apply_force(body_name="link1", force=[0, 0, 0]) + assert r2["status"] == "success" + + def test_wrong_length_force_errors(self, sim_with_robot): + res = sim_with_robot.apply_force(body_name="link1", force=[1, 2]) + assert res["status"] == "error" + assert "3-element" in res["content"][0]["text"] + + +# --- T8: negative/invalid mass, timestep ------------------------------- + + +class TestMassAndTimestepValidation: + def test_set_body_properties_negative_mass_errors(self, sim_with_robot): + res = sim_with_robot.set_body_properties(body_name="link1", mass=-1.0) + assert res["status"] == "error" + assert "must be > 0" in res["content"][0]["text"] + + def test_set_body_properties_zero_mass_errors(self, sim_with_robot): + res = sim_with_robot.set_body_properties(body_name="link1", mass=0.0) + assert res["status"] == "error" + + def test_set_body_properties_positive_mass_works(self, sim_with_robot): + res = sim_with_robot.set_body_properties(body_name="link1", mass=2.5) + assert res["status"] == "success" + + def test_set_timestep_negative_errors(self, sim_with_world): + res = sim_with_world.set_timestep(-0.01) + assert res["status"] == "error" + assert "> 
0" in res["content"][0]["text"] + + def test_set_timestep_zero_errors(self, sim_with_world): + res = sim_with_world.set_timestep(0) + assert res["status"] == "error" + + def test_set_timestep_positive_works(self, sim_with_world): + res = sim_with_world.set_timestep(0.001) + assert res["status"] == "success" + + def test_set_timestep_large_warns_but_succeeds(self, sim_with_world): + res = sim_with_world.set_timestep(0.5) + assert res["status"] == "success" + assert "⚠️" in res["content"][0]["text"] or "unusually" in res["content"][0]["text"] + + +# --- T38: set_gravity dim validation ----------------------------------- + + +class TestSetGravityValidation: + def test_two_element_gravity_errors(self, sim_with_world): + res = sim_with_world.set_gravity([0.0, 0.0]) + assert res["status"] == "error" + assert "3-element" in res["content"][0]["text"] + + def test_scalar_gravity_still_works(self, sim_with_world): + # Scalar form convenience (z-only) preserved + res = sim_with_world.set_gravity(-9.81) + assert res["status"] == "success" + + def test_full_vector_gravity_works(self, sim_with_world): + res = sim_with_world.set_gravity([1.0, 2.0, -9.0]) + assert res["status"] == "success" + + +# --- T11: set_joint_positions list/dict support ----------------------- + + +class TestSetJointPositionsForms: + def test_dict_form_works(self, sim_with_robot): + # Pick a valid joint name from the robot + joint_names = list(sim_with_robot._world.robots.values())[0].joint_names or [] + if not joint_names: + import pytest as _pytest + + _pytest.skip("robot has no named joints") + res = sim_with_robot.set_joint_positions(positions={joint_names[0]: 0.1}) + assert res["status"] == "success" + + def test_list_form_matches_count(self, sim_with_robot): + joint_names = list(sim_with_robot._world.robots.values())[0].joint_names or [] + if not joint_names: + import pytest as _pytest + + _pytest.skip("robot has no named joints") + res = sim_with_robot.set_joint_positions(positions=[0.0] * 
len(joint_names)) + assert res["status"] == "success", res["content"][0]["text"] + + def test_list_form_wrong_length_errors(self, sim_with_robot): + # 999 is almost certainly wrong for any robot + res = sim_with_robot.set_joint_positions(positions=[0.1] * 999) + assert res["status"] == "error" + assert "does not match" in res["content"][0]["text"] + + +# --- T5: policy-running guards ----------------------------------------- + + +class TestPolicyRunningGuards: + """Simulate policy-running state by poisoning _policy_threads. + + We insert a fake Future whose done() returns False so _require_no_running_policy + flags a running policy without actually starting one. + """ + + def _install_fake_running_policy(self, sim): + class _FakeRunningFuture: + def done(self): + return False + + sim._policy_threads["fake"] = _FakeRunningFuture() + + def test_reset_blocked(self, sim_with_robot): + self._install_fake_running_policy(sim_with_robot) + res = sim_with_robot.reset() + assert res["status"] == "error" + assert "while a policy is running" in res["content"][0]["text"] + + def test_set_gravity_blocked(self, sim_with_robot): + self._install_fake_running_policy(sim_with_robot) + res = sim_with_robot.set_gravity([0, 0, -5]) + assert res["status"] == "error" + assert "while a policy is running" in res["content"][0]["text"] + + def test_set_timestep_blocked(self, sim_with_robot): + self._install_fake_running_policy(sim_with_robot) + res = sim_with_robot.set_timestep(0.001) + assert res["status"] == "error" + assert "while a policy is running" in res["content"][0]["text"] + + def test_set_joint_positions_blocked(self, sim_with_robot): + self._install_fake_running_policy(sim_with_robot) + res = sim_with_robot.set_joint_positions(positions={"nope": 0.0}) + assert res["status"] == "error" + assert "while a policy is running" in res["content"][0]["text"] + + def test_apply_force_blocked(self, sim_with_robot): + self._install_fake_running_policy(sim_with_robot) + res = 
sim_with_robot.apply_force(body_name="link1", force=[1, 0, 0]) + assert res["status"] == "error" + assert "while a policy is running" in res["content"][0]["text"] + + def test_set_body_properties_blocked(self, sim_with_robot): + self._install_fake_running_policy(sim_with_robot) + res = sim_with_robot.set_body_properties(body_name="link1", mass=3.0) + assert res["status"] == "error" + assert "while a policy is running" in res["content"][0]["text"] + + def test_randomize_blocked(self, sim_with_robot): + self._install_fake_running_policy(sim_with_robot) + res = sim_with_robot.randomize(seed=42) + assert res["status"] == "error" + assert "while a policy is running" in res["content"][0]["text"] + + +# --- T6: add_robot initial state is zero ------------------------------- + + +class TestAddRobotInitialState: + """After add_robot, qpos/qvel/ctrl must be zero without needing reset().""" + + def test_initial_qpos_is_zero(self): + import numpy as np + + sim = Simulation() + try: + sim.create_world() + res = sim.add_robot(name="panda", data_config="panda") + if res["status"] != "success": + import pytest as _pytest + + _pytest.skip(f"panda not available: {res['content'][0]['text']}") + # IMPORTANT: do NOT call reset. T6 requires that add_robot itself leaves a clean state. 
+ data = sim._world._data + assert np.allclose(data.qpos, 0.0), f"qpos should be zero after add_robot, got {data.qpos}" + assert np.allclose(data.qvel, 0.0), f"qvel should be zero after add_robot, got {data.qvel}" + assert np.allclose(data.ctrl, 0.0), f"ctrl should be zero after add_robot, got {data.ctrl}" + finally: + sim.destroy() + + +# --- T3: render camera strict validation ------------------------------- + + +class TestRenderCameraValidation: + def test_unknown_camera_errors(self, sim_with_world): + res = sim_with_world.render(camera_name="does_not_exist", width=64, height=48) + assert res["status"] == "error" + assert "not found" in res["content"][0]["text"] + + def test_default_camera_labelled_honestly(self, sim_with_world): + res = sim_with_world.render(camera_name="default", width=64, height=48) + if res["status"] != "success": + import pytest as _pytest + + _pytest.skip(f"offscreen render unavailable: {res['content'][0]['text']}") + assert "free (default)" in res["content"][0]["text"] + + def test_free_alias_labelled_honestly(self, sim_with_world): + res = sim_with_world.render(camera_name="free", width=64, height=48) + if res["status"] != "success": + import pytest as _pytest + + _pytest.skip(f"offscreen render unavailable: {res['content'][0]['text']}") + assert "free (default)" in res["content"][0]["text"] + + def test_render_depth_unknown_camera_errors(self, sim_with_world): + res = sim_with_world.render_depth(camera_name="ghost_cam", width=64, height=48) + assert res["status"] == "error" + assert "not found" in res["content"][0]["text"] + + +# --- T2: camera target actually applied ----------------------------- + + +class TestAddCameraTargetOrients: + """The 'headline broken feature': add_camera(target=...) was silently dropped + so every custom camera rendered the same default view. These tests verify + that orientation now flows through to the rendered pixels. 
+ """ + + def _with_obj(self): + """Create a world with a distinguishable colored object for the cameras to frame.""" + sim = Simulation() + sim.create_world() + # Add a vivid red box at origin to make camera differences visible. + sim.add_object( + name="target_box", + shape="box", + size=[0.3, 0.3, 0.3], + position=[0.0, 0.0, 0.25], + color=[1.0, 0.0, 0.0, 1.0], + is_static=True, + ) + return sim + + def test_degenerate_target_equals_position_errors(self): + sim = self._with_obj() + try: + res = sim.add_camera(name="bad_cam", position=[1, 2, 3], target=[1, 2, 3]) + assert res["status"] == "error" + assert "identical" in res["content"][0]["text"] + finally: + sim.destroy() + + def test_wrong_length_position_errors(self): + sim = self._with_obj() + try: + res = sim.add_camera(name="bad_cam", position=[1, 2], target=[0, 0, 0]) + assert res["status"] == "error" + assert "3 elements" in res["content"][0]["text"] + finally: + sim.destroy() + + def test_camera_orientation_written(self): + """A target'd camera must end up with a non-default orientation in the + compiled model. Previously this test asserted on the raw ``xyaxes="..."`` + attribute in the scene XML, which the MjSpec builder path replaces with + a ``quat`` attribute. Both representations resolve to the same rotation + matrix in the compiled MjModel (``cam_mat0``) - which is what we + actually care about for rendering. + """ + import numpy as np + + sim = self._with_obj() + try: + res = sim.add_camera(name="side_cam", position=[2.0, 0.0, 0.3], target=[0.0, 0.0, 0.25]) + assert res["status"] == "success", res["content"][0]["text"] + + mj = sim._mj + model = sim._world._model + assert model is not None + cam_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, "side_cam") + assert cam_id >= 0, "camera was not registered in compiled model" + + # MuJoCo's default camera orientation is identity (looks along -Z). 
+ # Our target->quat conversion for position [2, 0, 0.3] looking at + # [0, 0, 0.25] must produce a non-identity rotation. + rot = model.cam_mat0[cam_id].reshape(3, 3) + assert not np.allclose(rot, np.eye(3)), "camera has default (identity) orientation - target was ignored" + finally: + sim.destroy() + + def test_different_targets_produce_different_orientations(self): + """Two cameras at the SAME position but different targets must produce + DIFFERENT rotation matrices in the compiled MjModel. Before the + camera-target fix (T* in PR #85) both cameras shared MuJoCo's default + look direction, so rendered frames were identical regardless of the + ``target`` argument. + + We assert on ``cam_mat0`` (the rotation matrix of the camera frame + at qpos0) rather than rendered pixels, because offscreen GL on some + CI runners produces blank frames and makes pixel comparison + unreliable. cam_mat0 is representation-agnostic - works under both + legacy MJCFBuilder (xyaxes-based) and SpecBuilder (quat-based) paths. + """ + import numpy as np + + sim = self._with_obj() + try: + res_a = sim.add_camera(name="cam_a", position=[2.0, 0.0, 0.5], target=[0.0, 0.0, 0.25]) + res_b = sim.add_camera(name="cam_b", position=[2.0, 0.0, 0.5], target=[0.0, 2.0, 0.25]) + assert res_a["status"] == "success" + assert res_b["status"] == "success" + + mj = sim._mj + model = sim._world._model + a_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, "cam_a") + b_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_CAMERA, "cam_b") + assert a_id >= 0 and b_id >= 0 + + rot_a = model.cam_mat0[a_id].reshape(3, 3) + rot_b = model.cam_mat0[b_id].reshape(3, 3) + assert not np.allclose(rot_a, rot_b, atol=1e-3), ( + "cameras with different targets must have different orientations " + "(their cam_mat0 rotation matrices are currently identical, which means " + "`target` is being ignored)." 
+ ) + finally: + sim.destroy() diff --git a/tests/simulation/mujoco/test_load_scene_interaction.py b/tests/simulation/mujoco/test_load_scene_interaction.py new file mode 100644 index 0000000..1278730 --- /dev/null +++ b/tests/simulation/mujoco/test_load_scene_interaction.py @@ -0,0 +1,297 @@ +"""Integration tests for ``load_scene`` interacting with downstream mutations. + +Regression suite for GH #115: ``load_scene`` previously did not populate +``_backend_state["xml"]`` / ``_backend_state["scene_loaded"]``, so subsequent +``add_object`` / ``add_camera`` / ``remove_object`` calls either: + +* recompiled the world via ``MJCFBuilder.build_objects_only``, silently + discarding every body/mesh from the loaded scene, or +* hit the XML round-trip path which fell through to ``mj_saveLastXML`` + global state and emitted the wrong (robot, not scene) XML. + +Each test here loads a scene, performs a mutation, and asserts the original +scene content survives and the mutation is reflected in the compiled model. +""" + +from __future__ import annotations + +import os +import tempfile +from collections.abc import Generator + +import pytest + +pytest.importorskip("mujoco") + +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +# Minimal scene: a ground plane + a named block body. This is *not* a robot - +# there are no joints/actuators/sensors. The original bug triggered when +# ``self._world.robots`` was empty, which is the case here. 
+SCENE_XML = """ + + +""" + + +@pytest.fixture +def scene_path() -> Generator[str, None, None]: + """Write the minimal scene XML to a temp file.""" + tmpdir = tempfile.mkdtemp(prefix="test_load_scene_") + path = os.path.join(tmpdir, "test_scene.xml") + with open(path, "w") as f: + f.write(SCENE_XML) + try: + yield path + finally: + import shutil + + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.fixture +def sim() -> Generator[Simulation, None, None]: + s = Simulation() + try: + yield s + finally: + s.cleanup() + + +def _world(sim: Simulation): + """Narrow `sim._world` from `SimWorld | None` to `SimWorld` for mypy. + + All tests here construct the world via load_scene / create_world before + inspecting state, so `sim._world` is definitely non-None at that point. + Wrap in this helper to keep assertions tidy. + """ + assert sim._world is not None + return sim._world + + +# ----------------------------------------------------------------------------- +# _backend_state population contract +# ----------------------------------------------------------------------------- + + +def test_load_scene_populates_backend_xml(sim: Simulation, scene_path: str) -> None: + """load_scene must cache the on-disk XML in _backend_state["xml"].""" + result = sim.load_scene(scene_path) + assert result["status"] == "success" + + stored = _world(sim)._backend_state.get("xml") + assert stored is not None, "scene XML must be cached for injection round-trip" + assert " None: + """load_scene must set the scene_loaded flag for downstream mutation gating.""" + sim.load_scene(scene_path) + assert _world(sim)._backend_state.get("scene_loaded") is True + + +def test_load_scene_records_scene_base_dir(sim: Simulation, scene_path: str) -> None: + """load_scene must record the scene's base dir for mesh path resolution.""" + sim.load_scene(scene_path) + base = _world(sim)._backend_state.get("scene_base_dir") + assert base is not None + assert os.path.isdir(base) + assert os.path.abspath(base) == 
os.path.dirname(os.path.abspath(scene_path)) + + +# ----------------------------------------------------------------------------- +# Scene survives downstream add_* mutations +# ----------------------------------------------------------------------------- + + +def test_add_object_after_load_scene_preserves_scene_bodies(sim: Simulation, scene_path: str) -> None: + """add_object after load_scene must inject via XML round-trip, not rebuild. + + The original bug: with no robots registered, add_object fell through to + _recompile_world() which called MJCFBuilder.build_objects_only - that + builder only knows about ``world.objects`` and rebuilt from scratch, + silently deleting every body from the loaded scene. + """ + sim.load_scene(scene_path) + mj = sim._mj + + # Establish baseline: the loaded scene has scene_block + scene_cylinder. + block_id_before = mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_block") + cyl_id_before = mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_cylinder") + assert block_id_before >= 0, "baseline: scene_block should exist in loaded scene" + assert cyl_id_before >= 0, "baseline: scene_cylinder should exist in loaded scene" + + # Now add an object. Bug: this used to wipe the scene. + result = sim.add_object(name="my_new_cube", shape="box", position=[0.0, 1.0, 0.1]) + assert result["status"] == "success", result + + # Loaded scene bodies must still exist. + block_id_after = mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_block") + cyl_id_after = mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_cylinder") + assert block_id_after >= 0, "scene_block was wiped by add_object (regression)" + assert cyl_id_after >= 0, "scene_cylinder was wiped by add_object (regression)" + + # And the newly added object must be in the model too. + # add_object injects a geom named '{name}_geom' under a body called '{name}'. 
+ new_body_id = mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "my_new_cube") + assert new_body_id >= 0, "newly added object not found in compiled model" + + +def test_add_camera_after_load_scene_preserves_scene_bodies(sim: Simulation, scene_path: str) -> None: + """add_camera after load_scene must also use the XML round-trip path. + + Same failure mode as add_object: the ``else`` branch called + ``_recompile_world()`` which wiped the loaded scene. + """ + sim.load_scene(scene_path) + mj = sim._mj + + result = sim.add_camera(name="top_cam", position=[0.0, 0.0, 5.0], target=[0.0, 0.0, 0.0]) + assert result["status"] == "success", result + + # Scene bodies survive + assert mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_block") >= 0 + assert mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_cylinder") >= 0 + # Camera injected + assert mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_CAMERA, "top_cam") >= 0 + + +def test_remove_object_after_load_scene_preserves_other_bodies(sim: Simulation, scene_path: str) -> None: + """remove_object on a loaded-scene world must use ejection round-trip. + + Previously it called _recompile_world() and wiped everything except + ``world.objects`` (which is empty post-load_scene). + """ + sim.load_scene(scene_path) + # Add, then remove. Both mutations must preserve the loaded scene. 
+ add_res = sim.add_object(name="temp_obj", shape="box", position=[0.5, 0.5, 0.5]) + assert add_res["status"] == "success" + + rm_res = sim.remove_object(name="temp_obj") + assert rm_res["status"] == "success", rm_res + + mj = sim._mj + # Loaded scene bodies survived the round-trip add + remove + assert mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_block") >= 0 + assert mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "scene_cylinder") >= 0 + # temp_obj is gone + assert mj.mj_name2id(_world(sim)._model, mj.mjtObj.mjOBJ_BODY, "temp_obj") < 0, ( + "remove_object did not actually eject the body from the scene" + ) + + +def test_create_world_does_not_set_scene_loaded(sim: Simulation) -> None: + """create_world (the non-load_scene path) must leave scene_loaded unset. + + Regression guard: if create_world accidentally set the flag, add_object + would mistakenly try to inject into a scene it can freely rebuild, which + is slower and goes through more code paths. + """ + result = sim.create_world() + assert result["status"] == "success" + assert not _world(sim)._backend_state.get("scene_loaded", False) + + +# ----------------------------------------------------------------------------- +# load_scene + add_robot: the original scenario from the BRUTAL_REVIEW.md +# ----------------------------------------------------------------------------- + + +ROBOT_XML_FOR_INJECTION = """ + + + +""" + + +@pytest.fixture +def robot_for_injection_path() -> Generator[str, None, None]: + tmpdir = tempfile.mkdtemp(prefix="test_inject_robot_") + path = os.path.join(tmpdir, "inject_arm.xml") + with open(path, "w") as f: + f.write(ROBOT_XML_FOR_INJECTION) + try: + yield path + finally: + import shutil + + shutil.rmtree(tmpdir, ignore_errors=True) + + +def test_add_robot_after_load_scene_preserves_scene_and_robot( + sim: Simulation, scene_path: str, robot_for_injection_path: str +) -> None: + """Load a scene, then inject a robot. Scene bodies + robot joints survive. 
+ + This is the exact scenario flagged in the second-opinion review: + + sim.load_scene(...) + sim.add_robot(...) + # Expected: scene bodies still there, robot is present + # Observed before fix: inject_robot_into_scene hits the + # stored_xml-is-None branch, mj_saveLastXML emits the wrong XML, + # and the merge breaks. + """ + # Step 1: load the scene. + res_scene = sim.load_scene(scene_path) + assert res_scene["status"] == "success" + + # Step 2: inject the robot. + res_robot = sim.add_robot(name="my_arm", urdf_path=robot_for_injection_path) + assert res_robot["status"] == "success", res_robot + + mj = sim._mj + model = _world(sim)._model + + # Scene bodies survive + assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "scene_block") >= 0, ( + "scene_block was lost after add_robot (regression)" + ) + assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "scene_cylinder") >= 0, ( + "scene_cylinder was lost after add_robot (regression)" + ) + + # Robot is namespaced under my_arm/ + # inject_robot_into_scene prefixes body/joint/actuator names with 'my_arm/' + assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "my_arm/arm_base") >= 0 + assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "my_arm/arm_pan") >= 0 + + +def test_add_robot_then_add_object_after_load_scene( + sim: Simulation, scene_path: str, robot_for_injection_path: str +) -> None: + """Full chain: load_scene → add_robot → add_object → all survive.""" + sim.load_scene(scene_path) + assert sim.add_robot(name="my_arm", urdf_path=robot_for_injection_path)["status"] == "success" + assert sim.add_object(name="box_a", shape="box", position=[0.3, 0.3, 0.3])["status"] == "success" + assert sim.add_object(name="box_b", shape="box", position=[0.5, 0.5, 0.5])["status"] == "success" + + mj = sim._mj + model = _world(sim)._model + # All four things from all three sources coexist. 
+ assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "scene_block") >= 0 + assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "my_arm/arm_base") >= 0 + assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "box_a") >= 0 + assert mj.mj_name2id(model, mj.mjtObj.mjOBJ_BODY, "box_b") >= 0 diff --git a/tests/simulation/mujoco/test_object_shapes.py b/tests/simulation/mujoco/test_object_shapes.py new file mode 100644 index 0000000..37c74e0 --- /dev/null +++ b/tests/simulation/mujoco/test_object_shapes.py @@ -0,0 +1,73 @@ +"""Every primitive shape supported by ``MJCFBuilder._object_xml`` must render. + +Also locks the scene-composer fallback path (``compose_multi_robot_scene``) +and the object-geom auto-naming convention (``_geom``). +""" + +from __future__ import annotations + +import os + +import pytest + +pytest.importorskip("mujoco") + +os.environ.setdefault("MUJOCO_GL", "glfw") + + +@pytest.fixture +def sim(): + from strands_robots.simulation import Simulation + + s = Simulation() + s.create_world() + yield s + s.destroy() + + +@pytest.mark.parametrize( + "shape,size,name", + [ + ("box", [0.02, 0.02, 0.02], "a_box"), + ("sphere", [0.025, 0.025, 0.025], "a_ball"), + ("cylinder", [0.02, 0.02, 0.06], "a_rod"), + ("capsule", [0.02, 0.02, 0.06], "a_capsule"), + ], +) +def test_primitive_shape_roundtrips_to_model(sim, shape, size, name): + r = sim.add_object(name=name, shape=shape, size=size, position=[0.1, 0.1, 0.05]) + assert r["status"] == "success", r + + # Geom is named by the convention '_geom' + import mujoco as mj + + gid = mj.mj_name2id(sim._world._model, mj.mjtObj.mjOBJ_GEOM, f"{name}_geom") + assert gid >= 0, f"geom '{name}_geom' not found in model" + + # And we can recolor it via geom_name (set_geom_properties coverage) + r = sim.set_geom_properties(geom_name=f"{name}_geom", color=[0.3, 0.3, 0.3, 1.0]) + assert r["status"] == "success" + + +def test_plane_object_auto_static(sim): + """T29: shape='plane' auto-sets is_static=True; add_object no longer + errors on plane 
shapes since they're now routed as static bodies + automatically.""" + r = sim.add_object(name="floor_mat", shape="plane", size=[0.5, 0.5, 0.001], position=[0, 0, 0.001]) + assert r["status"] == "success", r + assert sim._world.objects["floor_mat"].is_static is True + + +def test_plane_object_explicit_dynamic_rejected(sim): + """T29: Explicit is_static=False on a plane is a hard error - planes are + infinite and cannot be dynamic bodies in MuJoCo.""" + r = sim.add_object( + name="bad_floor", + shape="plane", + size=[0.5, 0.5, 0.001], + position=[0, 0, 0.001], + is_static=False, + ) + assert r["status"] == "error" + text = r["content"][0]["text"].lower() + assert "plane" in text and "is_static" in text diff --git a/tests/simulation/mujoco/test_physics.py b/tests/simulation/mujoco/test_physics.py new file mode 100644 index 0000000..ca847f2 --- /dev/null +++ b/tests/simulation/mujoco/test_physics.py @@ -0,0 +1,361 @@ +"""Tests for PhysicsMixin - advanced MuJoCo physics features. + +Tests: raycasting, jacobians, energy, forces, state checkpointing, +inverse dynamics, sensor readout, body introspection, runtime modification. + +Run: uv run pytest tests/test_physics.py -v +""" + +import json +import os + +import numpy as np +import pytest + +mj = pytest.importorskip("mujoco") + +from strands_robots.simulation.mujoco.simulation import Simulation # noqa: E402 + +ROBOT_XML = """ + + + +""" + + +@pytest.fixture +def sim(): + """Create a Simulation with the test scene loaded directly.""" + from strands_robots.simulation.models import SimStatus, SimWorld + + s = Simulation(tool_name="test_sim", mesh=False) + s._world = SimWorld() + s._world._model = mj.MjModel.from_xml_string(ROBOT_XML) + s._world._data = mj.MjData(s._world._model) + s._world.status = SimStatus.IDLE + mj.mj_forward(s._world._model, s._world._data) + yield s + s.cleanup() + + +def _extract_json_block(result, idx=1): + """Schema-tolerant: accepts both {"json": {...}} (new) and {"text": } (legacy). 
+ + The content-block schema is in flux; this helper ensures tests work against either. + """ + block = result["content"][idx] + if "json" in block: + return block["json"] + return json.loads(block["text"]) + + +class TestRaycasting: + def test_raycast_hits_ground(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, -1]) + assert result["status"] == "success" + data = _extract_json_block(result, 1) + assert data["hit"] is True + assert data["distance"] is not None + assert data["distance"] > 0 + + def test_raycast_hits_box(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, -1]) + assert result["status"] == "success" + data = _extract_json_block(result, 1) + assert data["hit"] is True + assert data["geom_name"] in ("box_geom", "ground") + + def test_raycast_misses(self, sim): + result = sim.raycast(origin=[0, 0, 2], direction=[0, 0, 1]) # shooting up + assert result["status"] == "success" + data = _extract_json_block(result, 1) + assert data["hit"] is False + + def test_multi_raycast(self, sim): + dirs = [[0, 0, -1], [1, 0, 0], [0, 1, 0], [0, 0, 1]] + result = sim.multi_raycast(origin=[0, 0, 2], directions=dirs) + assert result["status"] == "success" + rays = _extract_json_block(result, 1)["rays"] + assert len(rays) == 4 + # At least the downward ray should hit + assert rays[0]["distance"] is not None + + +class TestJacobians: + def test_body_jacobian(self, sim): + result = sim.get_jacobian(body_name="link2") + assert result["status"] == "success" + data = _extract_json_block(result, 1) + assert len(data["jacp"]) == 3 # 3×nv + assert data["nv"] == sim._world._model.nv + + def test_site_jacobian(self, sim): + result = sim.get_jacobian(site_name="end_effector") + assert result["status"] == "success" + + def test_geom_jacobian(self, sim): + result = sim.get_jacobian(geom_name="link2_geom") + assert result["status"] == "success" + + def test_jacobian_no_target(self, sim): + result = sim.get_jacobian() + assert result["status"] == 
"error" + + def test_jacobian_invalid_body(self, sim): + result = sim.get_jacobian(body_name="nonexistent") + assert result["status"] == "error" + + +class TestEnergy: + def test_get_energy(self, sim): + result = sim.get_energy() + assert result["status"] == "success" + data = _extract_json_block(result, 1) + assert "potential" in data + assert "kinetic" in data + assert "total" in data + # Box at height 0.5 should have nonzero potential energy + assert data["potential"] != 0 or data["kinetic"] != 0 + + def test_energy_changes_after_step(self, sim): + e1 = _extract_json_block(sim.get_energy(), 1) + # Step physics to let box fall + for _ in range(100): + mj.mj_step(sim._world._model, sim._world._data) + e2 = _extract_json_block(sim.get_energy(), 1) + # Kinetic energy should change (box falls) + assert e1["kinetic"] != e2["kinetic"] or e1["potential"] != e2["potential"] + + +class TestExternalForces: + def test_apply_force(self, sim): + result = sim.apply_force(body_name="box1", force=[0, 0, 100]) + assert result["status"] == "success" + assert "box1" in result["content"][0]["text"] + + def test_apply_force_invalid_body(self, sim): + result = sim.apply_force(body_name="nonexistent", force=[0, 0, 10]) + assert result["status"] == "error" + + def test_force_changes_acceleration(self, sim): + # Get initial state + data = sim._world._data + old_qfrc = data.qfrc_applied.copy() + sim.apply_force(body_name="box1", force=[0, 0, 100]) + # qfrc_applied should change + assert not np.array_equal(old_qfrc, data.qfrc_applied) + + +class TestMassMatrix: + def test_get_mass_matrix(self, sim): + result = sim.get_mass_matrix() + assert result["status"] == "success" + data = _extract_json_block(result, 1) + nv = sim._world._model.nv + assert data["shape"] == [nv, nv] + assert data["rank"] > 0 + assert data["total_mass"] > 0 + + def test_mass_diagonal_positive(self, sim): + result = sim.get_mass_matrix() + diag = _extract_json_block(result, 1)["diagonal"] + assert all(d >= 0 for d in 
diag) + + +class TestStateCheckpointing: + def test_save_and_load_state(self, sim): + # Set a known joint position + sim._world._data.qpos[7] = 1.0 # shoulder + mj.mj_forward(sim._world._model, sim._world._data) + + # Save + result = sim.save_state(name="test_checkpoint") + assert result["status"] == "success" + + # Change state + sim._world._data.qpos[7] = -1.0 + mj.mj_forward(sim._world._model, sim._world._data) + assert sim._world._data.qpos[7] == pytest.approx(-1.0) + + # Restore + result = sim.load_state(name="test_checkpoint") + assert result["status"] == "success" + assert sim._world._data.qpos[7] == pytest.approx(1.0) + + def test_load_nonexistent_checkpoint(self, sim): + result = sim.load_state(name="doesnt_exist") + assert result["status"] == "error" + + +class TestInverseDynamics: + def test_inverse_dynamics(self, sim): + mj.mj_forward(sim._world._model, sim._world._data) + result = sim.inverse_dynamics() + assert result["status"] == "success" + forces = _extract_json_block(result, 1)["qfrc_inverse"] + assert "shoulder" in forces or "elbow" in forces + + +class TestBodyState: + def test_get_body_state(self, sim): + result = sim.get_body_state(body_name="box1") + assert result["status"] == "success" + state = _extract_json_block(result, 1) + assert "position" in state + assert "quaternion" in state + assert "linear_velocity" in state + assert "angular_velocity" in state + assert "mass" in state + assert len(state["position"]) == 3 + assert len(state["quaternion"]) == 4 + assert state["mass"] == pytest.approx(1.0) + + def test_body_state_invalid(self, sim): + result = sim.get_body_state(body_name="nonexistent") + assert result["status"] == "error" + + +class TestDirectJointControl: + def test_set_joint_positions(self, sim): + result = sim.set_joint_positions(positions={"shoulder": 0.5, "elbow": -0.3}) + assert result["status"] == "success" + assert "2/2" in result["content"][0]["text"] + + # Verify positions were set + model, data = sim._world._model, 
sim._world._data + shoulder_id = mj.mj_name2id(model, mj.mjtObj.mjOBJ_JOINT, "shoulder") + qpos_adr = model.jnt_qposadr[shoulder_id] + assert data.qpos[qpos_adr] == pytest.approx(0.5) + + def test_set_joint_velocities(self, sim): + result = sim.set_joint_velocities(velocities={"shoulder": 1.0}) + assert result["status"] == "success" + + +class TestSensors: + def test_get_all_sensors(self, sim): + result = sim.get_sensor_data() + assert result["status"] == "success" + sensors = _extract_json_block(result, 1)["sensors"] + assert "shoulder_pos" in sensors + assert "elbow_pos" in sensors + + def test_get_specific_sensor(self, sim): + result = sim.get_sensor_data(sensor_name="shoulder_pos") + assert result["status"] == "success" + sensors = _extract_json_block(result, 1)["sensors"] + assert len(sensors) == 1 + assert "shoulder_pos" in sensors + + def test_sensor_values_change(self, sim): + # Set shoulder position + sim.set_joint_positions(positions={"shoulder": 1.0}) + result = sim.get_sensor_data(sensor_name="shoulder_pos") + val = _extract_json_block(result, 1)["sensors"]["shoulder_pos"]["values"] + assert abs(val - 1.0) < 0.01 + + +class TestRuntimeModification: + def test_set_body_mass(self, sim): + result = sim.set_body_properties(body_name="box1", mass=5.0) + assert result["status"] == "success" + body_id = mj.mj_name2id(sim._world._model, mj.mjtObj.mjOBJ_BODY, "box1") + assert sim._world._model.body_mass[body_id] == pytest.approx(5.0) + + def test_set_geom_color(self, sim): + result = sim.set_geom_properties(geom_name="box_geom", color=[0, 1, 0, 1]) + assert result["status"] == "success" + geom_id = mj.mj_name2id(sim._world._model, mj.mjtObj.mjOBJ_GEOM, "box_geom") + assert sim._world._model.geom_rgba[geom_id][1] == pytest.approx(1.0) + + def test_set_geom_friction(self, sim): + result = sim.set_geom_properties(geom_name="box_geom", friction=[0.5, 0.01, 0.001]) + assert result["status"] == "success" + + def test_invalid_geom(self, sim): + result = 
sim.set_geom_properties(geom_name="nonexistent", color=[1, 0, 0, 1]) + assert result["status"] == "error" + + +class TestContactForces: + def test_get_contact_forces_after_settling(self, sim): + # Let box fall and settle + for _ in range(500): + mj.mj_step(sim._world._model, sim._world._data) + result = sim.get_contact_forces() + assert result["status"] == "success" + # Box should be in contact with ground + contacts = _extract_json_block(result, 1)["contacts"] + assert len(contacts) > 0 + assert contacts[0]["normal_force"] != 0 + + +class TestForwardKinematics: + def test_forward_kinematics(self, sim): + result = sim.forward_kinematics() + assert result["status"] == "success" + bodies = _extract_json_block(result, 1)["bodies"] + assert "box1" in bodies + assert "link1" in bodies + assert len(bodies["box1"]["position"]) == 3 + + +class TestTotalMass: + def test_get_total_mass(self, sim): + result = sim.get_total_mass() + assert result["status"] == "success" + data = _extract_json_block(result, 1) + assert data["total_mass"] > 0 + assert "box1" in data["bodies"] + assert data["bodies"]["box1"] == pytest.approx(1.0) + + +class TestExportXML: + def test_export_xml_string(self, sim): + result = sim.export_xml() + assert result["status"] == "success" + text = result["content"][0]["text"] + assert "mujoco" in text.lower() or "Model XML" in text + + def test_export_xml_file(self, sim, tmp_path): + path = str(tmp_path / "exported.xml") + result = sim.export_xml(output_path=path) + assert result["status"] == "success" + assert os.path.exists(path) + with open(path) as f: + content = f.read() + assert " + +