diff --git a/.gitignore b/.gitignore index 9dd6112..2e430c6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ build dist .strands_robots .coverage +.ideation/ diff --git a/AGENTS.md b/AGENTS.md index 3fa245d..8a66852 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -84,3 +84,19 @@ hatch run format # ruff check --fix, ruff format 4. Open PR from your fork, address all review comments 5. Track follow-up items as issues on the [project board](https://github.com/orgs/strands-labs/projects/2) 6. Squash merge into `main` + + +## Registry conventions (strands_robots/registry/robots.json) + +- **Flat asset paths** (e.g. `"model_xml": "scene.xml"`) are the common case. +- **Nested asset paths** (e.g. `"model_xml": "xmls/asimov.xml"`) are allowed when + the upstream source repo uses a subdir layout. Example: `asimov_v0` maps to + `asimovinc/asimov-v0` which has `sim-model/xmls/asimov.xml` + + `sim-model/assets/`. The `_safe_join` helper in `strands_robots/utils.py` + guards against traversal (`..`). +- **Auto-download strategy** — every robot with an `asset` block must declare + exactly one of: + 1. `asset.robot_descriptions_module` (preferred) + 2. `asset.source` with `type: "github"` + 3. `asset.auto_download: false` (explicit opt-out) + Enforced by `tests/test_registry_integrity.py`. diff --git a/README.md b/README.md index 4332e4c..0a93a93 100644 --- a/README.md +++ b/README.md @@ -486,6 +486,31 @@ while True: agent.tool.gr00t_inference(action="stop", port=8000) ``` +## Configuration + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (MJCF, meshes) | `~/.strands_robots/assets/` | +| `GROOT_API_TOKEN` | API token for GR00T inference service | — | + +### Cache Directory + +Robot model assets (MJCF XML files and meshes) are cached in: + +``` +~/.strands_robots/ +└── assets/ # Downloaded robot models (from robot_descriptions / MuJoCo Menagerie) + ├── trs_so_arm100/ + ├── franka_emika_panda/ + └── ... +``` + +To clear the cache: `rm -rf ~/.strands_robots/assets/` + +To change the cache location: `export STRANDS_ASSETS_DIR=/path/to/custom/dir` + ## Contributing We welcome contributions! Please see: diff --git a/pyproject.toml b/pyproject.toml index 40382ec..f1a7090 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,9 +48,13 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] +sim = [ + "robot_descriptions>=1.11.0,<2.0.0", +] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", + "strands-robots[sim]", ] dev = [ "pytest>=6.0,<9.0.0", @@ -124,7 +128,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check diff --git a/strands_robots/assets/__init__.py b/strands_robots/assets/__init__.py new file mode 100644 index 0000000..4e9080c --- /dev/null +++ b/strands_robots/assets/__init__.py @@ -0,0 +1,41 @@ +"""Robot Asset Manager for Strands Robots Simulation. + +Assets are resolved from ``robot_descriptions`` package or downloaded from +MuJoCo Menagerie GitHub, cached in ``~/.strands_robots/assets/``. +Override with ``STRANDS_ASSETS_DIR`` env var. + +Implementation lives in ``assets/manager.py`` — this file is thin exports only. +""" + +from strands_robots.assets.manager import ( + get_robot_info, + list_available_robots, + resolve_model_dir, + resolve_model_path, +) +from strands_robots.registry import ( + format_robot_table, + get_robot, + list_aliases, + list_robots, + list_robots_by_category, +) +from strands_robots.registry import ( + resolve_name as resolve_robot_name, +) +from strands_robots.utils import get_assets_dir, get_search_paths + +__all__ = [ + "resolve_model_path", + "resolve_model_dir", + "resolve_robot_name", + "get_robot_info", + "list_available_robots", + "list_robots_by_category", + "list_aliases", + "format_robot_table", + "get_assets_dir", + "get_search_paths", + "get_robot", + "list_robots", +] diff --git a/strands_robots/assets/download.py b/strands_robots/assets/download.py new file mode 100644 index 0000000..a6ffa7d --- /dev/null +++ b/strands_robots/assets/download.py @@ -0,0 +1,476 @@ +"""Download robot model assets via ``robot_descriptions`` or custom GitHub repos. + +This module contains the core download logic for robot assets. +The ``strands_robots.tools.download_assets`` tool is a thin ``@tool`` wrapper +that delegates to :func:`download_robots` here. + +Strategy (in order of preference): + 1. ``robot_descriptions`` package — recommended by MuJoCo Menagerie. + 2. Shallow ``git clone`` fallback for Menagerie robots. + 3. Custom GitHub repos for non-Menagerie robots. + +Assets are cached in ``~/.strands_robots/assets/`` (override with +``STRANDS_ASSETS_DIR``). Install the optional dependency:: + + pip install strands-robots[sim-mujoco] # includes robot_descriptions +""" + +from __future__ import annotations + +import importlib +import logging +import os +import re +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any + +from ..registry import get_robot +from ..registry import list_robots as registry_list_robots +from ..registry import resolve_name as resolve_robot_name +from ..utils import get_assets_dir, get_search_paths, safe_join + +logger = logging.getLogger(__name__) + +MENAGERIE_REPO = "https://github.com/google-deepmind/mujoco_menagerie.git" + +# Only HTTPS GitHub URLs are allowed for cloning. +_ALLOWED_CLONE_URL_RE = re.compile(r"^https://github\.com/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+\.git$") + + +# ── robot_descriptions integration ──────────────────────────────────── + + +def _robot_descriptions_available() -> bool: + """Check if ``robot_descriptions`` is installed.""" + try: + import robot_descriptions # type: ignore[import-not-found] # noqa: F401 + + return True + except ImportError: + return False + + +def _resolve_robot_descriptions_module(name: str, info: dict) -> str | None: + """Resolve the ``robot_descriptions`` module name for a robot. + + Uses the ``robot_descriptions_module`` field from the registry (O(1)), + with a lightweight naming-convention fallback for unregistered robots. + + Args: + name: Canonical robot name. + info: Robot registry entry. + + Returns: + Module name (e.g. ``panda_mj_description``) or ``None``. + """ + asset = info.get("asset", {}) + + # Explicit opt-out: robot declares it has no robot_descriptions module + if asset.get("auto_download") is False: + return None + + # Primary: explicit registry entry (preferred, O(1)) + module_name: str | None = asset.get("robot_descriptions_module") + if module_name: + return str(module_name) + + # Fallback: try common naming conventions (max 3 imports) + asset_dir = info.get("asset", {}).get("dir", "") + candidates = [ + f"{asset_dir}_mj_description", + f"{name}_mj_description", + f"{name}_description", + ] + for candidate in candidates: + if not re.match(r"^[a-z0-9_]+$", candidate): + continue + try: + importlib.import_module(f"robot_descriptions.{candidate}") + logger.warning( + "Resolved '%s' via naming heuristic → '%s'. " + "Consider adding 'robot_descriptions_module' to the registry.", + name, + candidate, + ) + return candidate + except ImportError: + continue + + return None + + +# ── Helpers ─────────────────────────────────────────────────────────── + + +#: Alias for backward compatibility — use :func:`strands_robots.utils.get_assets_dir`. +get_user_assets_dir = get_assets_dir + + +def _needs_download(name: str, info: dict[str, Any] | None, force: bool = False) -> bool: + """Return *True* if a robot's mesh files are missing.""" + if info is None: + return False + asset = info.get("asset", {}) + if not asset: + return False + + xml_file, asset_dir = asset["model_xml"], asset["dir"] + + for search_dir in get_search_paths(): + model_path = search_dir / asset_dir / xml_file + if not model_path.exists(): + continue + try: + content = model_path.read_text() + mesh_files = re.findall(r'file="([^"]+\.(?:stl|STL|obj|OBJ|msh))"', content) + if not mesh_files: + return False + meshdir_match = re.search(r'meshdir="([^"]*)"', content) + meshdir = meshdir_match.group(1) if meshdir_match else "" + for mesh in mesh_files: + if not (model_path.parent / meshdir / mesh).exists(): + return True + return force + except Exception: + return True + + return True + + +def _get_source(info: dict[str, Any] | None) -> dict[str, Any]: + """Get download source for a robot. Defaults to ``menagerie``.""" + if info is None: + return {"type": "menagerie"} + source = info.get("asset", {}).get("source", {}) + return source if source else {"type": "menagerie"} + + +def _shallow_clone(repo_url: str, dest: str, *, timeout: int = 120) -> None: + """Shallow-clone *repo_url* into *dest*. + + Only HTTPS ``github.com`` URLs are accepted — ``ssh://``, ``git://``, + ``file://``, and other schemes are rejected to prevent command-injection + and SSRF risks. + + Raises: + ValueError: If *repo_url* does not match the allowed HTTPS GitHub pattern. + subprocess.CalledProcessError: If the ``git clone`` command fails. + subprocess.TimeoutExpired: If the clone exceeds *timeout* seconds. + """ + if not _ALLOWED_CLONE_URL_RE.match(repo_url): + raise ValueError(f"Blocked clone URL (only HTTPS github.com allowed): {repo_url!r}") + logger.info("Cloning %s (this may take a moment)...", repo_url) + subprocess.run( + ["git", "clone", "--depth", "1", repo_url, dest], + check=True, + capture_output=True, + timeout=timeout, + ) + + +# Filenames/patterns that are safe to strip from an upstream source tree before +# we copy it into the user's asset cache. Filtering at *copy* time (rather than +# deleting afterwards) means we never touch files that may already exist in *dst* +# — which matters when the user keeps notes/README alongside assets. +_COPY_CLEAN_SKIP = frozenset({"README.md", "LICENSE", "CHANGELOG.md"}) +_COPY_CLEAN_SUFFIX = (".png", ".jpg", ".jpeg") + + +def _copy_and_clean(src: Path, dst: Path) -> None: + """Copy *src* tree to *dst*, skipping non-essential files at copy time. + + Previous implementation deleted matching files from *dst* after copytree, + which meant a user's own ``README.md`` in the destination could be wiped. + This version filters on read so only files from *src* are dropped. + """ + + def _ignore(_dir: str, names: list[str]) -> list[str]: + return [ + n for n in names if n in _COPY_CLEAN_SKIP or n.lower().endswith(_COPY_CLEAN_SUFFIX) or n.startswith(".git") + ] + + shutil.copytree(str(src), str(dst), dirs_exist_ok=True, ignore=_ignore) + + +# ── Download backends ───────────────────────────────────────────────── + + +def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> dict[str, str]: + """Download robots using the ``robot_descriptions`` package. + + Imports only the specific module for each robot (O(1) per robot), + using the ``robot_descriptions_module`` field from the registry. + The import triggers the upstream clone on first use, then we symlink + ``PACKAGE_PATH`` into our asset cache. + """ + results: dict[str, str] = {} + if not robots: + return results + + for name, info in robots.items(): + asset_dir = info["asset"]["dir"] + module_name = _resolve_robot_descriptions_module(name, info) + if module_name is None: + results[name] = "skipped: no robot_descriptions module found" + continue + if not re.match(r"^[a-z0-9_]+$", module_name): + results[name] = f"skipped: invalid module name: {module_name}" + continue + + try: + mod = importlib.import_module(f"robot_descriptions.{module_name}") + package_path = Path(mod.PACKAGE_PATH) + if not package_path.exists(): + results[name] = f"failed: PACKAGE_PATH missing: {package_path}" + continue + + dst = safe_join(dest_dir, asset_dir) + if dst.is_symlink() and dst.resolve() == package_path.resolve(): + # Validate existing symlink still has the expected XML + expected_xml = dst / info["asset"]["model_xml"] + if expected_xml.exists(): + results[name] = "downloaded" + continue + # Stale symlink — remove and re-download via git + dst.unlink() + results[name] = f"failed: stale symlink — {info['asset']['model_xml']} not found in {package_path}" + continue + if dst.exists() or dst.is_symlink(): + dst.unlink() if dst.is_symlink() else shutil.rmtree(str(dst)) + + try: + dst.symlink_to(package_path) + except OSError: + shutil.copytree(str(package_path), str(dst), dirs_exist_ok=True) + + # Validate: expected XML must exist in the linked/copied dir + expected_xml = dst / info["asset"]["model_xml"] + if not expected_xml.exists(): + logger.warning( + "robot_descriptions module '%s' linked for %s but " + "expected XML '%s' not found — falling back to git", + module_name, + name, + info["asset"]["model_xml"], + ) + if dst.is_symlink(): + dst.unlink() + else: + shutil.rmtree(str(dst), ignore_errors=True) + results[name] = ( + f"failed: XML mismatch — module '{module_name}' does not contain {info['asset']['model_xml']}" + ) + continue + + results[name] = "downloaded" + except Exception as exc: + results[name] = f"failed: {exc}" + logger.warning("robot_descriptions failed for %s: %s", name, exc) + + return results + + +def _download_via_git(robots: dict[str, dict], dest_dir: Path) -> dict[str, str]: + """Fallback: shallow-clone Menagerie and copy robot directories.""" + results: dict[str, str] = {} + if not robots: + return results + + with tempfile.TemporaryDirectory() as tmpdir: + clone_dir = os.path.join(tmpdir, "mujoco_menagerie") + try: + _shallow_clone(MENAGERIE_REPO, clone_dir) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, ValueError) as exc: + reason = "timeout" if isinstance(exc, subprocess.TimeoutExpired) else str(exc)[:100] + return {n: f"failed: git clone {reason}" for n in robots} + + for name, info in robots.items(): + asset_dir = info["asset"]["dir"] + src = safe_join(Path(clone_dir), asset_dir) + if not src.exists(): + results[name] = f"failed: {asset_dir} not in menagerie" + continue + try: + _copy_and_clean(src, safe_join(dest_dir, asset_dir)) + results[name] = "downloaded" + except Exception as exc: + results[name] = f"failed: {exc}" + + return results + + +def _download_from_github(name: str, info: dict, dest_dir: Path) -> str: + """Download a robot from a custom GitHub repo (``asset.source``).""" + source = info["asset"]["source"] + repo = source["repo"] + if not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", repo): + return f"failed: invalid repo format: {repo}" + + subdir = source.get("subdir", "") + asset_dir = info["asset"]["dir"] + + with tempfile.TemporaryDirectory() as tmpdir: + clone_dir = os.path.join(tmpdir, "repo") + try: + # URL validation is enforced inside _shallow_clone itself + _shallow_clone(f"https://github.com/{repo}.git", clone_dir) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, ValueError) as exc: + reason = "timeout" if isinstance(exc, subprocess.TimeoutExpired) else str(exc)[:100] + return f"failed: git clone {reason}" + + src = Path(clone_dir) / subdir if subdir else Path(clone_dir) + if not src.exists(): + return f"failed: subdir '{subdir}' not found in {repo}" + + dst = safe_join(dest_dir, asset_dir) + try: + _copy_and_clean(src, dst) + return "downloaded" + except Exception as exc: + return f"failed: {exc}" + + +# ── Orchestrator ────────────────────────────────────────────────────── + + +def auto_download_robot(name: str, info: dict[str, Any]) -> bool: + """Auto-download a single robot's assets. + + Called by :func:`strands_robots.assets.manager.resolve_model_path` when + XML is present but meshes are missing. Tries ``robot_descriptions`` + first, then custom GitHub source if specified in the registry entry. + + Args: + name: Robot name (canonical or alias). + info: Registry entry for the robot. + + Returns: + ``True`` if a download attempt succeeded, ``False`` otherwise. + """ + dest_dir = get_assets_dir() + canonical = resolve_robot_name(name) + + # Try robot_descriptions first (covers most Menagerie robots) + if _robot_descriptions_available(): + results = _download_via_robot_descriptions({canonical: info}, dest_dir) + if results.get(canonical, "").startswith("downloaded"): + logger.info("Auto-downloaded %s via robot_descriptions", canonical) + return True + + # Fall back to custom GitHub source + source = info.get("asset", {}).get("source", {}) + if source.get("type") == "github": + result = _download_from_github(canonical, info, dest_dir) + if result.startswith("downloaded"): + logger.info("Auto-downloaded %s from GitHub", canonical) + return True + + return False + + +def download_robots( + names: list[str] | None = None, + category: str | None = None, + force: bool = False, +) -> dict[str, Any]: + """Download robot model assets from their respective sources. + + Strategy (in order of preference): + 1. ``robot_descriptions`` package — recommended by MuJoCo Menagerie. + 2. Shallow ``git clone`` fallback for Menagerie robots. + 3. Custom GitHub repos for non-Menagerie robots. + + Args: + names: Robot names to download (``None`` = all sim robots). + category: Filter by category (arm, humanoid, mobile, …). + force: Re-download even if present. + + Returns: + Dict with downloaded/skipped/failed counts, names, and details. + """ + dest_dir = get_user_assets_dir() + # Filter None values — get_robot() can return None for unknown names + all_sim: dict[str, dict[str, Any]] = { + r["name"]: info for r in registry_list_robots(mode="sim") if (info := get_robot(r["name"])) is not None + } + + # Resolve requested robots + if names: + robots: dict[str, dict[str, Any]] = {} + for name in names: + canonical = resolve_robot_name(name) + if canonical in all_sim: + robots[canonical] = all_sim[canonical] + else: + logger.warning("Unknown robot: %s (resolved: %s)", name, canonical) + elif category: + robots = {n: i for n, i in all_sim.items() if i.get("category") == category} + else: + robots = dict(all_sim) + + if not robots: + return {"downloaded": 0, "skipped": 0, "failed": 0, "message": "No matching robots found."} + + # Partition: needs download vs already present + to_download: dict[str, dict[str, Any]] = {} + skipped: list[str] = [] + for name, info in robots.items(): + if _needs_download(name, info, force): + to_download[name] = info + else: + skipped.append(name) + + if not to_download: + return { + "downloaded": 0, + "skipped": len(skipped), + "failed": 0, + "skipped_names": skipped, + "message": f"All {len(robots)} robots already have assets. Use force=True to re-download.", + } + + # Partition by source type + menagerie_robots: dict[str, Any] = {} + github_robots: dict[str, Any] = {} + for name, info in to_download.items(): + source = _get_source(info) + bucket = github_robots if source["type"] == "github" else menagerie_robots + bucket[name] = info + + # Download Menagerie robots (robot_descriptions → git fallback) + results: dict[str, str] = {} + if menagerie_robots: + if _robot_descriptions_available(): + results.update(_download_via_robot_descriptions(menagerie_robots, dest_dir)) + # Retry failures with git clone + retry = { + n: menagerie_robots[n] for n, r in results.items() if r.startswith("failed") or r.startswith("skipped") + } + if retry: + results.update(_download_via_git(retry, dest_dir)) + else: + results.update(_download_via_git(menagerie_robots, dest_dir)) + + # Download custom GitHub robots + for name, info in github_robots.items(): + results[name] = _download_from_github(name, info, dest_dir) + + downloaded = [n for n, r in results.items() if r == "downloaded"] + failed = {n: r for n, r in results.items() if r != "downloaded"} + method = "robot_descriptions" if _robot_descriptions_available() else "git clone" + + return { + "downloaded": len(downloaded), + "skipped": len(skipped), + "failed": len(failed), + "downloaded_names": downloaded, + "skipped_names": skipped, + "failed_names": list(failed), + "failed_details": failed, + "assets_dir": str(dest_dir), + "method": method, + "message": (f"{len(downloaded)} downloaded ({method}), {len(skipped)} already present, {len(failed)} failed."), + } diff --git a/strands_robots/assets/manager.py b/strands_robots/assets/manager.py new file mode 100644 index 0000000..f8a2b0b --- /dev/null +++ b/strands_robots/assets/manager.py @@ -0,0 +1,271 @@ +"""Robot Asset Manager for Strands Robots Simulation. + +Resolves robot model files (MJCF XML) from: + 1. ``STRANDS_ASSETS_DIR`` env var (user override) + 2. User cache (``~/.strands_robots/assets/``) + 3. ``robot_descriptions`` package (MuJoCo Menagerie) + 4. Project-local ``./assets/`` +""" + +import logging +import os +from pathlib import Path + +from strands_robots.registry import ( + get_robot, + list_robots, +) +from strands_robots.registry import ( + resolve_name as resolve_robot_name, +) +from strands_robots.utils import get_search_paths, safe_join + +logger = logging.getLogger(__name__) + +# Module-level conditional import — keeps manager.py importable in +# environments where the optional ``robot_descriptions`` package (and its +# transitive heavyweight deps like ``GitPython``) are not installed. +# When ``download`` is not available, auto-download simply returns False. +try: + from .download import auto_download_robot as _auto_download_robot_impl +except ImportError: + _auto_download_robot_impl = None # type: ignore[assignment] + + +# ───────────────────────────────────────────────────────────────────── +# Model path resolution (delegates to registry) +# ───────────────────────────────────────────────────────────────────── + + +def _auto_download_robot(name: str, info: dict) -> bool: + """Delegate to :func:`strands_robots.assets.download.auto_download_robot`. + + Returns ``False`` immediately when the download module is unavailable + (e.g. ``robot_descriptions`` not installed). + """ + if _auto_download_robot_impl is None: + logger.warning("Auto-download unavailable: install strands-robots[sim-mujoco] for automatic asset downloads") + return False + return _auto_download_robot_impl(name, info) + + +_MESH_EXTS = frozenset({".stl", ".obj", ".msh", ".ply"}) + +# Cache of (directory, mtime) -> has_meshes result. Avoids re-walking the tree +# when ``resolve_model_path`` checks multiple candidate locations for the same +# robot and when it re-checks after auto-download. +_MESH_CACHE: dict[tuple[str, float], bool] = {} + + +def _has_meshes(directory: Path) -> bool: + """Check if a directory tree contains mesh files (cached, early-exit). + + Uses ``os.scandir`` with an early break on the first mesh found rather + than ``rglob("*")``, which stats every file. Result is cached per + (directory, mtime) so repeated calls are free. + """ + if not directory.exists(): + return False + try: + cache_key = (str(directory), directory.stat().st_mtime) + except OSError: + cache_key = (str(directory), 0.0) + cached = _MESH_CACHE.get(cache_key) + if cached is not None: + return cached + + def _walk(path: str) -> bool: + try: + with os.scandir(path) as it: + for entry in it: + if entry.is_file(follow_symlinks=False): + ext = os.path.splitext(entry.name)[1].lower() + if ext in _MESH_EXTS: + return True + elif entry.is_dir(follow_symlinks=False) and _walk(entry.path): + return True + except OSError: + return False + return False + + result = _walk(str(directory)) + _MESH_CACHE[cache_key] = result + return result + + +def _resolve_candidates(asset_dir_name: str, xml_file: str, name: str) -> list[Path]: + """Resolve candidate paths for a robot XML, with path-traversal protection. + + Uses ``safe_join`` to prevent ``../`` in registry-sourced ``asset_dir_name`` + or ``xml_file`` from escaping the search directories. + """ + candidates: list[Path] = [] + for search_dir in get_search_paths(): + try: + model_path = safe_join(search_dir, f"{asset_dir_name}/{xml_file}") + except ValueError: + logger.warning("Path traversal attempt blocked for robot: %s", name) + return [] + if model_path.exists(): + candidates.append(model_path) + return candidates + + +def resolve_model_path( + name: str, + prefer_scene: bool = False, +) -> Path | None: + """Resolve a robot name to its MJCF model XML path. + + Looks up the robot in ``registry/robots.json``, then searches + the asset directories for the actual file. If XML is found but + mesh files are missing, automatically downloads them via + ``robot_descriptions`` before returning. + + Args: + name: Robot name (canonical or alias). + prefer_scene: If True, return scene XML (with ground/lights) + instead of bare model XML. + + Returns: + Path to the MJCF XML file, or None if not found. + + Examples:: + + resolve_model_path("so100") # → .../trs_so_arm100/so_arm100.xml + resolve_model_path("so100", prefer_scene=True) # → .../trs_so_arm100/scene.xml + resolve_model_path("franka") # → .../franka_emika_panda/panda.xml + """ + info = get_robot(name) + if not info or "asset" not in info: + logger.warning("Unknown robot or no asset: %s", name) + return None + + asset = info["asset"] + # Explicit str() casts: dict subscript returns Any, but Path / Any → Any + xml_file: str = str(asset["scene_xml"] if prefer_scene else asset["model_xml"]) + asset_dir_name: str = str(asset["dir"]) + + candidates: list[Path] = [] + + # Check user-registered asset path first (highest priority). + # ``xml_file`` comes from user_robots.json, so we still gate it through + # :func:`safe_join` to block path traversal even for user-authored entries + # (defense in depth — protects against a compromised user_robots.json and + # keeps the trust boundary identical to the built-in registry path). + user_path = info.get("_user_asset_path") + if user_path: + try: + user_model = safe_join(Path(user_path), xml_file) + except ValueError: + logger.warning( + "Path traversal blocked in _user_asset_path for %s: %r", + name, + xml_file, + ) + user_model = None + if user_model is not None and user_model.exists(): + candidates.append(user_model) + + # Search standard paths with traversal protection + candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) + + if not candidates: + # No XML found at all — try auto-download, then re-search + logger.info("No XML found for %s, attempting auto-download...", name) + if _auto_download_robot(name, info): + candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) + + if not candidates: + logger.warning("Robot model not found: %s → %s/%s", name, asset_dir_name, xml_file) + return None + + # Prefer the candidate whose directory contains mesh files, + # because an XML without meshes will fail to load in MuJoCo. + for path in candidates: + if _has_meshes(path.parent): + logger.debug("Resolved %s → %s (has meshes)", name, path) + return Path(path) + + # XML found but no meshes — auto-download and re-check + logger.info("XML found for %s but no meshes, attempting auto-download...", name) + if _auto_download_robot(name, info): + # Re-scan after download (new symlinks may have appeared) + refreshed = _resolve_candidates(asset_dir_name, xml_file, name) + for path in refreshed: + if _has_meshes(path.parent): + logger.debug("Resolved %s → %s (auto-downloaded)", name, path) + return Path(path) + + # Final fallback: return first candidate (some robots have no meshes) + logger.debug("Resolved %s → %s (no meshes available)", name, candidates[0]) + return Path(candidates[0]) + + +def resolve_model_dir(name: str) -> Path | None: + """Resolve a robot name to its asset directory (containing XML + meshes). + + Args: + name: Robot name (canonical or alias). + + Returns: + Path to the robot's asset directory, or None if not found. + """ + info = get_robot(name) + if not info or "asset" not in info: + return None + + asset_dir: str = str(info["asset"]["dir"]) + for search_dir in get_search_paths(): + try: + dir_path = safe_join(search_dir, asset_dir) + except ValueError: + logger.warning("Path traversal attempt blocked in resolve_model_dir: %s", asset_dir) + return None + if dir_path.exists(): + return Path(dir_path) + return None + + +def get_robot_info(name: str) -> dict | None: + """Get information about a robot model. + + Args: + name: Robot name (canonical or alias). + + Returns: + Dict with description, category, joints, asset info, etc. + """ + info = get_robot(name) + if info is None: + return None + result = dict(info) + result["canonical_name"] = resolve_robot_name(name) + path = resolve_model_path(name) + result["resolved_path"] = str(path) if path else None + result["available"] = path is not None + return result + + +def list_available_robots() -> list[dict]: + """List all available robot models with their info. + + Returns: + List of dicts with name, description, joints, category, available, path. + """ + robots = [] + for r in list_robots(mode="sim"): + path = resolve_model_path(r["name"]) + info = get_robot(r["name"]) or {} + robots.append( + { + "name": r["name"], + "description": r.get("description", ""), + "joints": r.get("joints"), + "category": r.get("category", ""), + "dir": info.get("asset", {}).get("dir", ""), + "available": path is not None, + "path": str(path) if path else None, + } + ) + return robots diff --git a/strands_robots/registry/__init__.py b/strands_robots/registry/__init__.py index 3430ce1..2d6ba3d 100644 --- a/strands_robots/registry/__init__.py +++ b/strands_robots/registry/__init__.py @@ -29,7 +29,7 @@ policies.json ← policy providers (shorthands/urls inside each entry) """ -from .loader import reload +from .loader import invalidate_cache, reload from .policies import ( build_policy_kwargs, get_policy_provider, @@ -48,6 +48,11 @@ list_robots_by_category, resolve_name, ) +from .user_registry import ( + list_user_robots, + register_robot, + unregister_robot, +) __all__ = [ # Robot registry @@ -66,6 +71,11 @@ "resolve_policy", "import_policy_class", "build_policy_kwargs", + # User-local registry + "register_robot", + "unregister_robot", + "list_user_robots", # Utilities "reload", + "invalidate_cache", ] diff --git a/strands_robots/registry/loader.py b/strands_robots/registry/loader.py index 188271d..99da64e 100644 --- a/strands_robots/registry/loader.py +++ b/strands_robots/registry/loader.py @@ -35,6 +35,11 @@ def _load(name: str) -> dict: if name not in _cache or _mtimes.get(name) != mtime: with open(path, encoding="utf-8") as f: data = json.load(f) + + # Merge user-local robot registry (overlay on top of package JSON) + if name == "robots": + data = _merge_user_robots(data) + _validate(name, data) _cache[name] = data _mtimes[name] = mtime @@ -43,6 +48,29 @@ def _load(name: str) -> dict: return _cache[name] +def _merge_user_robots(data: dict) -> dict: + """Merge user-local robot registry on top of package robots.json. + + User entries override package entries on name collision. + """ + try: + from .user_registry import get_user_robots + except ImportError: + return data + + user_robots = get_user_robots() + if not user_robots: + return data + + merged = dict(data) + merged_robots = dict(merged.get("robots", {})) + merged_robots.update(user_robots) + merged["robots"] = merged_robots + + logger.debug("Merged %d user-registered robot(s) into registry", len(user_robots)) + return merged + + def _validate(name: str, data: dict) -> None: """Validate uniqueness constraints after loading a registry file. @@ -103,3 +131,17 @@ def reload() -> None: """Force-reload all registry files (clears mtime cache).""" _cache.clear() _mtimes.clear() + + +def invalidate_cache(name: str | None = None) -> None: + """Invalidate cached registry data, forcing a reload on next access. + + Args: + name: Registry name to invalidate (e.g. "robots"). If None, clears all. + """ + if name is None: + _cache.clear() + _mtimes.clear() + else: + _cache.pop(name, None) + _mtimes.pop(name, None) diff --git a/strands_robots/registry/robots.json b/strands_robots/registry/robots.json index 5c552b2..97df780 100644 --- a/strands_robots/registry/robots.json +++ b/strands_robots/registry/robots.json @@ -1,48 +1,99 @@ { "robots": { - "so100": { - "description": "TrossenRobotics SO-ARM100 (6-DOF, Feetech servos)", - "category": "arm", - "joints": 13, + "crazyflie": { + "description": "Bitcraze Crazyflie 2 Nano-Quadcopter", + "category": "aerial", + "joints": 1, "asset": { - "dir": "trs_so_arm100", - "model_xml": "so_arm100.xml", + "dir": "bitcraze_crazyflie_2", + "model_xml": "cf2.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "so_arm100_mj_description" + "robot_descriptions_module": "cf2_mj_description" }, - "hardware": { - "lerobot_type": "so100_follower" - }, - "legacy_urdf": "so100/so100.urdf", "aliases": [ - "so100_4cam", - "so100_dualcam", - "so100_follower", - "so_arm100", - "trs_so_arm100" + "cf2", + "bitcraze_crazyflie" ] }, - "so101": { - "description": "RobotStudio SO-101 (6-DOF, upgraded SO-100)", + "skydio_x2": { + "description": "Skydio X2 Autonomous Drone", + "category": "aerial", + "joints": 1, + "asset": { + "dir": "skydio_x2", + "model_xml": "x2.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "skydio_x2_mj_description" + } + }, + "arx_l5": { + "description": "ARX L5 (6-DOF lightweight arm)", "category": "arm", - "joints": 9, + "joints": 11, "asset": { - "dir": "robotstudio_so101", - "model_xml": "so101.xml", - "scene_xml": "scene_box.xml", - "robot_descriptions_module": "so_arm101_mj_description" + "dir": "arx_l5", + "model_xml": "arx_l5.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "arx_l5_mj_description" + } + }, + "dynamixel_2r": { + "description": "Dynamixel 2R Educational Arm (2-DOF)", + "category": "arm", + "joints": 2, + "asset": { + "dir": "dynamixel_2r", + "model_xml": "dynamixel_2r.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "dynamixel_2r_mj_description" + } + }, + "fr3": { + "description": "Franka Research 3 (7-DOF + gripper)", + "category": "arm", + "joints": 8, + "asset": { + "dir": "franka_fr3", + "model_xml": "fr3.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "fr3_mj_description" }, - "hardware": { - "lerobot_type": "so101_follower" + "aliases": [ + "franka_fr3" + ] + }, + "fr3_v2": { + "description": "Franka Research 3 v2 (7-DOF + gripper, updated)", + "category": "arm", + "joints": 7, + "asset": { + "dir": "franka_fr3_v2", + "model_xml": "fr3v2.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "fr3_v2_mj_description" }, "aliases": [ - "robotstudio_so101", - "so101_dualcam", - "so101_follower", - "so101_leader", - "so101_tricam" + "franka_fr3_v2" ] }, + "hope_jr": { + "description": "Hope Junior arm", + "category": "arm", + "hardware": { + "lerobot_type": "hope_jr" + } + }, + "kinova_gen3": { + "description": "Kinova Gen3 (7-DOF lightweight)", + "category": "arm", + "joints": 7, + "asset": { + "dir": "kinova_gen3", + "model_xml": "gen3.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "gen3_mj_description" + } + }, "koch": { "description": "Koch v1.1 Low Cost Robot Arm (6-DOF, Dynamixel)", "category": "arm", @@ -62,6 +113,48 @@ "low_cost_robot_arm" ] }, + "kuka_iiwa": { + "description": "KUKA LBR iiwa 14 (7-DOF collaborative)", + "category": "arm", + "joints": 11, + "asset": { + "dir": "kuka_iiwa_14", + "model_xml": "iiwa14.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "iiwa14_mj_description" + }, + "aliases": [ + "kuka_iiwa_14" + ] + }, + "omx": { + "description": "OMX Robot Arm (ROBOTIS, CAN bus motors)", + "category": "arm", + "hardware": { + "lerobot_type": "omx" + }, + "aliases": [ + "omx_follower", + "omx_robot", + "robotis_omx" + ] + }, + "openarm": { + "description": "Enactic OpenArm (7-DOF, DAMIAO motors, CAN bus)", + "category": "arm", + "joints": 9, + "asset": { + "dir": "enactic_openarm", + "model_xml": "openarm.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "openarm_v1_mj_description" + }, + "aliases": [ + "enactic_openarm", + "open_arm", + "openarm_v10" + ] + }, "panda": { "description": "Franka Emika Panda (7-DOF + gripper)", "category": "arm", @@ -84,70 +177,98 @@ "single_panda_gripper" ] }, - "fr3": { - "description": "Franka Research 3 (7-DOF + gripper)", + "piper": { + "description": "AgileX Piper (6-DOF + gripper)", "category": "arm", - "joints": 8, + "joints": 11, "asset": { - "dir": "franka_fr3", - "model_xml": "fr3.xml", + "dir": "agilex_piper", + "model_xml": "piper.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "fr3_mj_description" + "robot_descriptions_module": "piper_mj_description" }, "aliases": [ - "franka_fr3", - "franka_fr3_v2" + "agilex_piper" ] }, - "ur5e": { - "description": "Universal Robots UR5e (6-DOF industrial)", + "sawyer": { + "description": "Rethink Robotics Sawyer (7-DOF)", "category": "arm", - "joints": 8, + "joints": 7, "asset": { - "dir": "universal_robots_ur5e", - "model_xml": "ur5e.xml", + "dir": "rethink_robotics_sawyer", + "model_xml": "sawyer.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "ur5e_mj_description" - } + "robot_descriptions_module": "sawyer_mj_description" + }, + "aliases": [ + "rethink_sawyer" + ] }, - "kuka_iiwa": { - "description": "KUKA LBR iiwa 14 (7-DOF collaborative)", + "so100": { + "description": "TrossenRobotics SO-ARM100 (6-DOF, Feetech servos)", "category": "arm", - "joints": 11, + "joints": 13, "asset": { - "dir": "kuka_iiwa_14", - "model_xml": "iiwa14.xml", + "dir": "trs_so_arm100", + "model_xml": "so_arm100.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "iiwa14_mj_description" + "robot_descriptions_module": "so_arm100_mj_description" + }, + "hardware": { + "lerobot_type": "so100_follower" }, + "legacy_urdf": "so100/so100.urdf", "aliases": [ - "kuka_iiwa_14" + "so100_4cam", + "so100_dualcam", + "so100_follower", + "so_arm100", + "trs_so_arm100" ] }, - "kinova_gen3": { - "description": "Kinova Gen3 (7-DOF lightweight)", + "so101": { + "description": "RobotStudio SO-101 (6-DOF, upgraded SO-100)", "category": "arm", - "joints": 7, + "joints": 9, "asset": { - "dir": "kinova_gen3", - "model_xml": "gen3.xml", + "dir": "robotstudio_so101", + "model_xml": "so101.xml", + "scene_xml": "scene_box.xml", + "robot_descriptions_module": "so_arm101_mj_description" + }, + "hardware": { + "lerobot_type": "so101_follower" + }, + "aliases": [ + "robotstudio_so101", + "so101_dualcam", + "so101_follower", + "so101_leader", + "so101_tricam" + ] + }, + "ur10e": { + "description": "Universal Robots UR10e (6-DOF industrial)", + "category": "arm", + "joints": 6, + "asset": { + "dir": "universal_robots_ur10e", + "model_xml": "ur10e.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "gen3_mj_description" + "robot_descriptions_module": "ur10e_mj_description" } }, - "xarm7": { - "description": "UFactory xArm 7 (7-DOF + gripper)", + "ur5e": { + "description": "Universal Robots UR5e (6-DOF industrial)", "category": "arm", - "joints": 13, + "joints": 8, "asset": { - "dir": "ufactory_xarm7", - "model_xml": "xarm7.xml", + "dir": "universal_robots_ur5e", + "model_xml": "ur5e.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "xarm7_mj_description" - }, - "aliases": [ - "ufactory_xarm7" - ] + "robot_descriptions_module": "ur5e_mj_description" + } }, "vx300s": { "description": "Trossen ViperX 300s (6-DOF + gripper)", @@ -165,59 +286,61 @@ "viper_x300s" ] }, - "arx_l5": { - "description": "ARX L5 (6-DOF lightweight arm)", + "wx250s": { + "description": "Trossen WidowX 250s (6-DOF + gripper)", "category": "arm", - "joints": 11, + "joints": 16, "asset": { - "dir": "arx_l5", - "model_xml": "arx_l5.xml", + "dir": "trossen_wx250s", + "model_xml": "wx250s.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "arx_l5_mj_description" - } + "robot_descriptions_module": "widow_mj_description" + }, + "aliases": [ + "widowx_250s", + "trossen_wx250s" + ] }, - "piper": { - "description": "AgileX Piper (6-DOF + gripper)", + "xarm7": { + "description": "UFactory xArm 7 (7-DOF + gripper)", "category": "arm", - "joints": 11, + "joints": 13, "asset": { - "dir": "agilex_piper", - "model_xml": "piper.xml", + "dir": "ufactory_xarm7", + "model_xml": "xarm7.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "piper_mj_description" + "robot_descriptions_module": "xarm7_mj_description" }, "aliases": [ - "agilex_piper" + "ufactory_xarm7" ] }, - "z1": { - "description": "Unitree Z1 (6-DOF + gripper)", + "yam": { + "description": "i2rt YAM Arm (8-DOF)", "category": "arm", "joints": 8, "asset": { - "dir": "unitree_z1", - "model_xml": "z1_gripper.xml", + "dir": "i2rt_yam", + "model_xml": "yam.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "z1_mj_description" + "robot_descriptions_module": "yam_mj_description" }, "aliases": [ - "unitree_z1" + "i2rt_yam" ] }, - "openarm": { - "description": "Enactic OpenArm (7-DOF, DAMIAO motors, CAN bus)", + "z1": { + "description": "Unitree Z1 (6-DOF + gripper)", "category": "arm", - "joints": 9, + "joints": 8, "asset": { - "dir": "enactic_openarm", - "model_xml": "openarm.xml", + "dir": "unitree_z1", + "model_xml": "z1_gripper.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "openarm_v1_mj_description" + "robot_descriptions_module": "z1_mj_description" }, "aliases": [ - "enactic_openarm", - "open_arm", - "openarm_v10" + "unitree_z1" ] }, "aloha": { @@ -243,6 +366,18 @@ "galaxea_r1_pro" ] }, + "bi_openarm": { + "description": "Bi-manual OpenArm (dual-arm coordination)", + "category": "bimanual", + "hardware": { + "lerobot_type": "bi_openarm" + }, + "aliases": [ + "bi_openarm_follower", + "dual_openarm", + "openarm_bimanual" + ] + }, "trossen_wxai": { "description": "Trossen WidowX AI Bimanual", "category": "bimanual", @@ -251,69 +386,344 @@ "dir": "trossen_wxai", "model_xml": "trossen_ai_bimanual.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "widow_mj_description" + "auto_download": false + }, + "aliases": [ + "trossen_ai_bimanual" + ] + }, + "reachy_mini": { + "description": "Pollen Reachy Mini (6-DOF Stewart head + antennas, 9 actuators)", + "category": "expressive", + "joints": 21, + "asset": { + "dir": "reachy_mini", + "model_xml": "mjcf/reachy_mini.xml", + "scene_xml": "mjcf/scene.xml", + "source": { + "type": "github", + "repo": "pollen-robotics/reachy_mini", + "subdir": "src/reachy_mini/descriptions/reachy_mini" + } + }, + "aliases": [ + "pollen_reachy_mini", + "reachy", + "reachy-mini", + "reachymini" + ] + }, + "ability_hand": { + "description": "PSYONIC Ability Hand (5-finger prosthetic, 11-DOF)", + "category": "hand", + "joints": 11, + "asset": { + "dir": "mujoco_xml", + "model_xml": "scene.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "ability_hand_mj_description" + }, + "aliases": [ + "psyonic_ability_hand" + ] + }, + "aero_hand": { + "description": "Tetheria Aero Hand Open (16-DOF dexterous)", + "category": "hand", + "joints": 16, + "asset": { + "dir": "tetheria_aero_hand_open", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "aero_hand_open_mj_description" + }, + "aliases": [ + "tetheria_aero_hand", + "aero_hand_open" + ] + }, + "allegro_hand": { + "description": "Wonik Allegro Hand (16-DOF dexterous)", + "category": "hand", + "joints": 16, + "asset": { + "dir": "wonik_allegro", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "allegro_hand_mj_description" + }, + "aliases": [ + "wonik_allegro" + ] + }, + "leap_hand": { + "description": "LEAP Hand (16-DOF dexterous)", + "category": "hand", + "joints": 41, + "asset": { + "dir": "leap_hand", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "leap_hand_mj_description" + } + }, + "robotiq_2f85": { + "description": "Robotiq 2F-85 Gripper (2-finger adaptive)", + "category": "hand", + "joints": 16, + "asset": { + "dir": "robotiq_2f85", + "model_xml": "2f85.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "robotiq_2f85_mj_description" + }, + "aliases": [ + "robotiq" + ] + }, + "robotiq_2f85_v4": { + "description": "Robotiq 2F-85 v4 Gripper (updated model)", + "category": "hand", + "joints": 6, + "asset": { + "dir": "robotiq_2f85_v4", + "model_xml": "2f85.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "robotiq_2f85_v4_mj_description" + } + }, + "shadow_dexee": { + "description": "Shadow DexEE Dexterous End-Effector (12-DOF)", + "category": "hand", + "joints": 12, + "asset": { + "dir": "shadow_dexee", + "model_xml": "shadow_dexee.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "shadow_dexee_mj_description" + } + }, + "shadow_hand": { + "description": "Shadow Dexterous Hand (24-DOF)", + "category": "hand", + "joints": 45, + "asset": { + "dir": "shadow_hand", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "shadow_hand_mj_description" + } + }, + "adam_lite": { + "description": "PNDbotics Adam Lite Humanoid (26-DOF)", + "category": "humanoid", + "joints": 26, + "asset": { + "dir": "pndbotics_adam_lite", + "model_xml": "adam_lite.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "adam_lite_mj_description" + }, + "aliases": [ + "pndbotics_adam_lite" + ] + }, + "apollo": { + "description": "Apptronik Apollo Humanoid (34-DOF)", + "category": "humanoid", + "joints": 34, + "asset": { + "dir": "apptronik_apollo", + "model_xml": "apptronik_apollo.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "apollo_mj_description" + }, + "aliases": [ + "apptronik_apollo" + ] + }, + "asimov_v0": { + "description": "Asimov V0 Bipedal Legs (12-DOF + 2 passive toes)", + "category": "humanoid", + "joints": 15, + "asset": { + "dir": "asimov_v0", + "model_xml": "xmls/asimov.xml", + "scene_xml": "xmls/asimov.xml", + "source": { + "type": "github", + "repo": "asimovinc/asimov-v0", + "subdir": "sim-model" + } + }, + "aliases": [ + "asimov" + ] + }, + "booster_t1": { + "description": "Booster T1 Humanoid (24-DOF)", + "category": "humanoid", + "joints": 24, + "asset": { + "dir": "booster_t1", + "model_xml": "t1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "booster_t1_mj_description" + } + }, + "cassie": { + "description": "Agility Cassie Bipedal Robot", + "category": "humanoid", + "joints": 28, + "asset": { + "dir": "agility_cassie", + "model_xml": "cassie.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "cassie_mj_description" + }, + "aliases": [ + "agility_cassie" + ] + }, + "elf2": { + "description": "BXI Elf2 Humanoid (25-DOF)", + "category": "humanoid", + "joints": 26, + "asset": { + "dir": "xml", + "model_xml": "elf2_dof25.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "elf2_mj_description" + }, + "aliases": [ + "bxi_elf2" + ] + }, + "fourier_n1": { + "description": "Fourier N1 / GR-1 Humanoid (26-DOF)", + "category": "humanoid", + "joints": 26, + "asset": { + "dir": "fourier_n1", + "model_xml": "n1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "n1_mj_description" + }, + "aliases": [ + "fourier_gr1", + "fourier_gr1_arms_only", + "fourier_gr1_arms_waist", + "fourier_gr1_full_upper_body", + "gr1" + ] + }, + "jvrc": { + "description": "JVRC-1 Humanoid (HRP-based, 45-DOF)", + "category": "humanoid", + "joints": 45, + "asset": { + "dir": "jvrc_mj_description", + "model_xml": "xml/jvrc1.xml", + "scene_xml": "xml/jvrc1.xml", + "robot_descriptions_module": "jvrc_mj_description" + }, + "aliases": [ + "jvrc1" + ] + }, + "op3": { + "description": "ROBOTIS OP3 Humanoid (20-DOF)", + "category": "humanoid", + "joints": 21, + "asset": { + "dir": "robotis_op3", + "model_xml": "op3.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "op3_mj_description" + }, + "aliases": [ + "robotis_op3" + ] + }, + "open_duck_mini": { + "description": "Open Duck Mini V2 (16-DOF expressive biped, Feetech servos)", + "category": "humanoid", + "joints": 16, + "asset": { + "dir": "open_duck_mini_v2", + "model_xml": "open_duck_mini_v2.xml", + "scene_xml": "scene.xml", + "source": { + "type": "github", + "repo": "apirrone/Open_Duck_Mini", + "subdir": "mini_bdx/robots/open_duck_mini_v2" + } }, "aliases": [ - "trossen_ai_bimanual" + "bdx", + "mini_bdx", + "open_duck", + "open_duck_mini_v2", + "open_duck_v2" ] }, - "shadow_hand": { - "description": "Shadow Dexterous Hand (24-DOF)", - "category": "hand", - "joints": 45, + "rby1": { + "description": "Rainbow Robotics RB-Y1A Mobile Manipulator (31-DOF)", + "category": "humanoid", + "joints": 31, "asset": { - "dir": "shadow_hand", - "model_xml": "left_hand.xml", - "scene_xml": "scene_left.xml", - "robot_descriptions_module": "shadow_hand_mj_description" + "dir": "mujoco", + "model_xml": "model.xml", + "scene_xml": "model.xml", + "robot_descriptions_module": "rby1_mj_description" }, "aliases": [ - "shadow_dexee" + "rby1a", + "rainbow_rby1" ] }, - "leap_hand": { - "description": "LEAP Hand (16-DOF dexterous)", - "category": "hand", - "joints": 41, - "asset": { - "dir": "leap_hand", - "model_xml": "left_hand.xml", - "scene_xml": "scene_left.xml", - "robot_descriptions_module": "leap_hand_mj_description" + "reachy2": { + "description": "Pollen Reachy 2", + "category": "humanoid", + "hardware": { + "lerobot_type": "reachy2" } }, - "robotiq_2f85": { - "description": "Robotiq 2F-85 Gripper (2-finger adaptive)", - "category": "hand", - "joints": 16, + "talos": { + "description": "PAL Robotics TALOS Humanoid (32-DOF)", + "category": "humanoid", + "joints": 45, "asset": { - "dir": "robotiq_2f85", - "model_xml": "2f85.xml", + "dir": "pal_talos", + "model_xml": "talos.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "robotiq_2f85_mj_description" + "robot_descriptions_module": "talos_mj_description" }, "aliases": [ - "robotiq", - "robotiq_2f85_v4" + "pal_talos" ] }, - "fourier_n1": { - "description": "Fourier N1 / GR-1 Humanoid (26-DOF)", + "toddlerbot_2xc": { + "description": "Toddlerbot 2xC Humanoid (45-DOF)", "category": "humanoid", - "joints": 26, + "joints": 45, "asset": { - "dir": "fourier_n1", - "model_xml": "n1.xml", + "dir": "toddlerbot_2xc", + "model_xml": "toddlerbot_2xc.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "n1_mj_description" - }, - "aliases": [ - "fourier_gr1", - "fourier_gr1_arms_only", - "fourier_gr1_arms_waist", - "fourier_gr1_full_upper_body", - "gr1" - ] + "robot_descriptions_module": "toddlerbot_2xc_mj_description" + } + }, + "toddlerbot_2xm": { + "description": "Toddlerbot 2xM Humanoid (45-DOF)", + "category": "humanoid", + "joints": 45, + "asset": { + "dir": "toddlerbot_2xm", + "model_xml": "toddlerbot_2xm.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "toddlerbot_2xm_mj_description" + } }, "unitree_g1": { "description": "Unitree G1 Humanoid (29-DOF + dexterous hands)", @@ -351,121 +761,107 @@ "h1" ] }, - "apollo": { - "description": "Apptronik Apollo Humanoid (34-DOF)", + "unitree_h1_2": { + "description": "Unitree H1-2 Humanoid (52-DOF, with hands)", "category": "humanoid", - "joints": 34, + "joints": 52, "asset": { - "dir": "apptronik_apollo", - "model_xml": "apptronik_apollo.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "apollo_mj_description" + "dir": "h1_2_description", + "model_xml": "h1_2.xml", + "scene_xml": "h1_2.xml", + "robot_descriptions_module": "h1_2_mj_description" }, "aliases": [ - "apptronik_apollo" + "h1_2" ] }, - "cassie": { - "description": "Agility Cassie Bipedal Robot", - "category": "humanoid", - "joints": 28, + "aliengo": { + "description": "Unitree Aliengo Quadruped (12-DOF)", + "category": "mobile", + "joints": 13, "asset": { - "dir": "agility_cassie", - "model_xml": "cassie.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "cassie_mj_description" + "dir": "aliengo", + "model_xml": "xml/aliengo.xml", + "scene_xml": "xml/aliengo.xml", + "robot_descriptions_module": "aliengo_mj_description" }, "aliases": [ - "agility_cassie" + "unitree_aliengo" ] }, - "open_duck_mini": { - "description": "Open Duck Mini V2 (16-DOF expressive biped, Feetech servos)", - "category": "humanoid", - "joints": 16, + "anymal_b": { + "description": "ANYbotics ANYmal B Quadruped (12-DOF)", + "category": "mobile", + "joints": 13, "asset": { - "dir": "open_duck_mini_v2", - "model_xml": "open_duck_mini_v2.xml", + "dir": "anybotics_anymal_b", + "model_xml": "anymal_b.xml", "scene_xml": "scene.xml", - "source": { - "type": "github", - "repo": "apirrone/Open_Duck_Mini", - "subdir": "mini_bdx/robots/open_duck_mini_v2" - } + "robot_descriptions_module": "anymal_b_mj_description" }, "aliases": [ - "bdx", - "mini_bdx", - "open_duck", - "open_duck_mini_v2", - "open_duck_v2" + "anybotics_anymal_b" ] }, - "asimov_v0": { - "description": "Asimov V0 Bipedal Legs (12-DOF + 2 passive toes)", - "category": "humanoid", - "joints": 15, + "anymal_c": { + "description": "ANYbotics ANYmal C Quadruped (12-DOF)", + "category": "mobile", + "joints": 13, "asset": { - "dir": "asimov_v0", - "model_xml": "asimov_v0.xml", + "dir": "anybotics_anymal_c", + "model_xml": "anymal_c.xml", "scene_xml": "scene.xml", - "source": { - "type": "github", - "repo": "asimovinc/asimov-v0", - "subdir": "sim-model" - } + "robot_descriptions_module": "anymal_c_mj_description" }, "aliases": [ - "asimov" + "anybotics_anymal_c" ] }, - "reachy_mini": { - "description": "Pollen Reachy Mini (6-DOF Stewart head + antennas, 9 actuators)", - "category": "expressive", - "joints": 21, - "asset": { - "dir": "reachy_mini", - "model_xml": "mjcf/reachy_mini.xml", - "scene_xml": "mjcf/scene.xml", - "source": { - "type": "github", - "repo": "pollen-robotics/reachy_mini", - "subdir": "src/reachy_mini/descriptions/reachy_mini" - } + "earthrover": { + "description": "EarthRover Mini Plus (mobile outdoor navigation)", + "category": "mobile", + "hardware": { + "lerobot_type": "earthrover" }, "aliases": [ - "pollen_reachy_mini", - "reachy", - "reachy-mini", - "reachymini" + "earth_rover", + "earthrover_mini_plus", + "frodobots" ] }, - "unitree_go2": { - "description": "Unitree Go2 Quadruped", + "go1": { + "description": "Unitree Go1 Quadruped (12-DOF)", "category": "mobile", - "joints": 40, + "joints": 13, "asset": { - "dir": "unitree_go2", - "model_xml": "go2.xml", + "dir": "unitree_go1", + "model_xml": "go1.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "go2_mj_description" + "robot_descriptions_module": "go1_mj_description" }, "aliases": [ - "go2" + "unitree_go1" ] }, - "unitree_a1": { - "description": "Unitree A1 Quadruped", + "lekiwi": { + "description": "LeKiwi mobile robot", "category": "mobile", - "joints": 16, + "hardware": { + "lerobot_type": "lekiwi" + } + }, + "robot_soccer_kit": { + "description": "Robot Soccer Kit (multi-robot soccer, 65-DOF total)", + "category": "mobile", + "joints": 65, "asset": { - "dir": "unitree_a1", - "model_xml": "a1.xml", + "dir": "robot_soccer_kit", + "model_xml": "robot_soccer_kit.xml", "scene_xml": "scene.xml", - "robot_descriptions_module": "a1_mj_description" + "robot_descriptions_module": "rsk_mj_description" }, "aliases": [ - "a1" + "rsk" ] }, "spot": { @@ -482,6 +878,20 @@ "boston_dynamics_spot" ] }, + "stretch": { + "description": "Hello Robot Stretch (original, mobile manipulator)", + "category": "mobile", + "joints": 18, + "asset": { + "dir": "hello_robot_stretch", + "model_xml": "stretch.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "stretch_mj_description" + }, + "aliases": [ + "hello_robot_stretch_original" + ] + }, "stretch3": { "description": "Hello Robot Stretch 3 (mobile manipulator)", "category": "mobile", @@ -497,74 +907,61 @@ "hello_robot_stretch_3" ] }, - "google_robot": { - "description": "Google Robot (mobile base + arm, RT-X)", - "category": "mobile_manip", - "joints": 10, + "tiago_dual": { + "description": "PAL Robotics TIAGo++ Dual-Arm Mobile (26-DOF)", + "category": "mobile", + "joints": 26, "asset": { - "dir": "google_robot", - "model_xml": "robot.xml", - "scene_xml": "scene.xml" + "dir": "pal_tiago_dual", + "model_xml": "tiago_dual.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "tiago++_mj_description" }, "aliases": [ - "oxe_google" + "tiago++", + "pal_tiago_dual" ] }, - "lekiwi": { - "description": "LeKiwi mobile robot", - "category": "mobile", - "hardware": { - "lerobot_type": "lekiwi" - } - }, - "reachy2": { - "description": "Pollen Reachy 2", - "category": "humanoid", - "hardware": { - "lerobot_type": "reachy2" - } - }, - "hope_jr": { - "description": "Hope Junior arm", - "category": "arm", - "hardware": { - "lerobot_type": "hope_jr" - } - }, - "earthrover": { - "description": "EarthRover Mini Plus (mobile outdoor navigation)", + "unitree_a1": { + "description": "Unitree A1 Quadruped", "category": "mobile", - "hardware": { - "lerobot_type": "earthrover" + "joints": 16, + "asset": { + "dir": "unitree_a1", + "model_xml": "a1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "a1_mj_description" }, "aliases": [ - "earth_rover", - "earthrover_mini_plus", - "frodobots" + "a1" ] }, - "omx": { - "description": "OMX Robot Arm (ROBOTIS, CAN bus motors)", - "category": "arm", - "hardware": { - "lerobot_type": "omx" + "unitree_go2": { + "description": "Unitree Go2 Quadruped", + "category": "mobile", + "joints": 40, + "asset": { + "dir": "unitree_go2", + "model_xml": "go2.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "go2_mj_description" }, "aliases": [ - "omx_follower", - "omx_robot", - "robotis_omx" + "go2" ] }, - "bi_openarm": { - "description": "Bi-manual OpenArm (dual-arm coordination)", - "category": "bimanual", - "hardware": { - "lerobot_type": "bi_openarm" + "google_robot": { + "description": "Google Robot (mobile base + arm, RT-X)", + "category": "mobile_manip", + "joints": 10, + "asset": { + "dir": "google_robot", + "model_xml": "robot.xml", + "scene_xml": "scene.xml", + "auto_download": false }, "aliases": [ - "bi_openarm_follower", - "dual_openarm", - "openarm_bimanual" + "oxe_google" ] } } diff --git a/strands_robots/registry/user_registry.py b/strands_robots/registry/user_registry.py new file mode 100644 index 0000000..eb55843 --- /dev/null +++ b/strands_robots/registry/user_registry.py @@ -0,0 +1,321 @@ +"""User-local robot registry — runtime registration without editing package JSON. + +Provides ``register_robot()`` and ``unregister_robot()`` for adding custom +robots that persist across sessions via a ``user_robots.json`` file stored +alongside the asset cache. + +File location (in priority order): + 1. ``$STRANDS_BASE_DIR/user_robots.json`` + 2. ``~/.strands_robots/user_robots.json`` + +Note: + ``STRANDS_ASSETS_DIR`` only controls where *assets* live, not the + user registry. Use ``STRANDS_BASE_DIR`` to relocate user metadata. + +At load time the user overlay is merged *on top of* the package +``robots.json`` — user entries win on name collision, so you can also +override built-in robots locally. + +Usage:: + + from strands_robots.registry import register_robot, unregister_robot + + # Register a custom robot with MJCF + register_robot( + name="my_arm", + model_xml="my_arm.xml", + description="My custom 6-DOF arm", + category="arm", + joints=6, + asset_dir="my_arm", # resolved relative to assets dir + ) + + # Now works everywhere: + from strands_robots.simulation import create_simulation + sim = create_simulation() + sim.create_world() + sim.add_robot("my_arm") # ✅ auto-resolved + + # Remove it + unregister_robot("my_arm") +""" + +import json +import logging +from pathlib import Path +from typing import Any + +from strands_robots.utils import get_base_dir, resolve_asset_path + +from .loader import invalidate_cache + +logger = logging.getLogger(__name__) + + +def _get_user_registry_path() -> Path: + """Get path to the user-local robot registry file.""" + return get_base_dir() / "user_robots.json" + + +def _load_user_registry() -> dict[str, Any]: + """Load the user-local robot registry file. + + Returns: + Dict with ``"robots"`` key mapping names to robot definitions. + """ + path = _get_user_registry_path() + if not path.exists(): + return {"robots": {}} + try: + with open(path, encoding="utf-8") as f: + data = json.load(f) + if "robots" not in data: + data = {"robots": {}} + return data + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to load user registry %s: %s", path, exc) + return {"robots": {}} + + +def _save_user_registry(data: dict[str, Any]) -> None: + """Save the user-local robot registry file.""" + path = _get_user_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=4) + f.write("\n") + logger.info("Saved user registry: %s (%d robots)", path, len(data.get("robots", {}))) + + +def get_user_robots() -> dict[str, Any]: + """Get all user-registered robots. + + Returns: + Dict mapping robot names to their definitions. + """ + return _load_user_registry().get("robots", {}) + + +def register_robot( + name: str, + *, + model_xml: str, + description: str = "", + category: str = "arm", + joints: int = 0, + asset_dir: str | None = None, + scene_xml: str | None = None, + aliases: list[str] | None = None, + robot_descriptions_module: str | None = None, + hardware: dict[str, Any] | None = None, + overwrite: bool = False, +) -> dict[str, Any]: + """Register a custom robot in the user-local registry. + + .. warning:: Security + + This function is a **library-only** API and must NOT be exposed + as an agent @tool without additional safeguards. A malicious + agent could register a robot pointing to attacker-controlled MJCF + that executes code via MuJoCo plugins. If tool exposure is needed + in the future, gate it behind STRANDS_TRUST_REMOTE_CODE and + validate all paths with _safe_join. + + The robot becomes immediately available in ``get_robot()``, + ``list_robots()``, ``resolve_model_path()``, ``sim.add_robot()``, etc. + + Args: + name: Canonical robot name (lowercase, underscores). + model_xml: Path to MJCF/URDF model file, relative to ``asset_dir``. + description: Human-readable description. + category: Robot category (arm, humanoid, mobile, hand, aerial, bimanual, ...). + joints: Number of actuated joints. + asset_dir: Directory containing the model file and meshes. + - Absolute path: used as-is (``~/`` expanded). + - Relative path: resolved against the assets directory + (``STRANDS_ASSETS_DIR`` or ``~/.strands_robots/assets/``). + - None: defaults to ``//``. + scene_xml: Scene XML (with ground/lights). Defaults to ``model_xml``. + aliases: Alternative names for this robot. + robot_descriptions_module: Optional ``robot_descriptions`` module name. + hardware: Optional hardware config dict (``lerobot_type``, etc.). + overwrite: If False (default), raises ValueError if robot already exists. + + Returns: + The registered robot definition dict. + + Raises: + ValueError: If name already exists and ``overwrite`` is False. + FileNotFoundError: If ``model_xml`` doesn't exist at the resolved path. + + Example:: + + register_robot( + name="my_arm", + model_xml="my_arm.xml", + asset_dir="~/robots/my_arm_v2", + description="Custom 6-DOF arm with gripper", + category="arm", + joints=7, + aliases=["myarm", "custom_arm"], + ) + """ + # Normalize name + name = name.lower().strip().replace("-", "_") + + # Load existing + data = _load_user_registry() + + # Check for existing (in user registry AND package registry) + if not overwrite: + if name in data.get("robots", {}): + raise ValueError(f"Robot '{name}' already in user registry. Use overwrite=True to replace.") + # Also check package registry + try: + from .robots import get_robot as _pkg_get_robot + + if _pkg_get_robot(name) is not None: + logger.info( + "Robot '%s' exists in package registry — user registration will override it.", + name, + ) + except ImportError: + pass + + # Resolve asset_dir via shared utility (respects STRANDS_ASSETS_DIR) + resolved_dir = resolve_asset_path(asset_dir, default_name=name) + + # Use the directory name as the asset "dir" key (relative to search paths) + # This matches how resolve_model_path works: search_dir / asset["dir"] / xml + dir_name = resolved_dir.name + + # Alias collision detection — warn (don't fail) when a user alias shadows a + # canonical name or another alias. Doing this at registration surfaces the + # problem immediately instead of at silent resolution-order time. + if aliases and not overwrite: + try: + from .robots import get_robot as _pkg_get_robot + from .robots import list_robots as _pkg_list_robots + + pkg_canonical = {r["name"] for r in _pkg_list_robots()} + pkg_aliases: set[str] = set() + for r in _pkg_list_robots(): + pkg_aliases.update(r.get("aliases", []) or []) + except Exception: + pkg_canonical = set() + pkg_aliases = set() + + user_existing = data.get("robots", {}) + user_canonical = set(user_existing.keys()) + user_aliases: set[str] = set() + for _r in user_existing.values(): + user_aliases.update(_r.get("aliases", []) or []) + + for alias in aliases: + if alias in pkg_canonical or alias in user_canonical: + logger.warning("Alias %r shadows an existing robot canonical name.", alias) + elif alias in pkg_aliases or alias in user_aliases: + logger.warning("Alias %r is already used by another robot.", alias) + + # Validate model_xml exists. Previously we only checked when + # ``resolved_dir`` existed — which silently accepted registrations for + # dirs that didn't exist yet and surfaced a confusing error only at + # ``add_robot()`` time. Now we fail-closed on both conditions so the + # user gets an immediate, actionable error at registration time. + model_path = resolved_dir / model_xml + if not resolved_dir.exists(): + raise FileNotFoundError( + f"Asset directory does not exist: {resolved_dir}\n" + f"Create the directory and place '{model_xml}' inside it before registering." + ) + if not model_path.exists(): + raise FileNotFoundError(f"Model XML not found: {model_path}\nEnsure '{model_xml}' exists in '{resolved_dir}'") + + # Build entry + entry: dict[str, Any] = { + "description": description, + "category": category, + "joints": joints, + "asset": { + "dir": dir_name, + "model_xml": model_xml, + "scene_xml": scene_xml or model_xml, + }, + } + + if robot_descriptions_module: + entry["asset"]["robot_descriptions_module"] = robot_descriptions_module + + if aliases: + entry["aliases"] = aliases + + if hardware: + entry["hardware"] = hardware + + # Store the full resolved path so the asset manager can find it + # even if the dir isn't in the standard search paths + entry["_user_asset_path"] = str(resolved_dir) + + # Save + data.setdefault("robots", {})[name] = entry + _save_user_registry(data) + + # Invalidate loader cache so next get_robot() picks up the merge + _invalidate_cache() + + logger.info("Registered robot '%s' → %s/%s", name, dir_name, model_xml) + return entry + + +def unregister_robot(name: str) -> bool: + """Remove a robot from the user-local registry. + + Does not affect the package ``robots.json``. If the robot exists + only in the package registry, this is a no-op. + + Args: + name: Robot name to remove. + + Returns: + True if the robot was removed, False if it wasn't in the user registry. + """ + name = name.lower().strip().replace("-", "_") + data = _load_user_registry() + + if name not in data.get("robots", {}): + logger.info("Robot '%s' not in user registry — nothing to remove.", name) + return False + + del data["robots"][name] + _save_user_registry(data) + _invalidate_cache() + + logger.info("Unregistered robot '%s'", name) + return True + + +def list_user_robots() -> list[dict[str, Any]]: + """List all user-registered robots. + + Returns: + List of dicts with name, description, category, path info. + """ + robots = get_user_robots() + result = [] + for name, info in sorted(robots.items()): + result.append( + { + "name": name, + "description": info.get("description", ""), + "category": info.get("category", ""), + "joints": info.get("joints", 0), + "asset_dir": info.get("_user_asset_path", ""), + "model_xml": info.get("asset", {}).get("model_xml", ""), + } + ) + return result + + +def _invalidate_cache() -> None: + """Invalidate the loader cache so merged data is reloaded.""" + invalidate_cache("robots") diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py new file mode 100644 index 0000000..d9674a9 --- /dev/null +++ b/strands_robots/simulation/__init__.py @@ -0,0 +1,111 @@ +"""Strands Robots Simulation — multi-backend simulation framework. + +Architecture:: + + simulation/ + ├── __init__.py ← this file (re-exports, lazy loading) + ├── base.py ← SimEngine ABC + ├── factory.py ← create_simulation() + backend registration + ├── models.py ← shared dataclasses (SimWorld, SimRobot, ...) + └── model_registry.py ← URDF/MJCF resolution (shared across backends) + + # MuJoCo backend added in subsequent PRs. + +Usage:: + + # Default (MuJoCo) via factory + from strands_robots.simulation import create_simulation + sim = create_simulation() + + # Direct class access + from strands_robots.simulation import Simulation + sim = Simulation() + + # Explicit backend + from strands_robots.simulation.mujoco import MuJoCoSimulation + + # Shared types (no heavy deps) + from strands_robots.simulation import SimWorld, SimRobot, SimObject + + # ABC for custom backends + from strands_robots.simulation.base import SimEngine + +Future backends:: + + from strands_robots.simulation.isaac import IsaacSimulation + from strands_robots.simulation.newton import NewtonSimulation +""" + +import importlib as _importlib +from typing import Any + +# --- Light imports (no heavy deps — stdlib + dataclasses only) --- +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.factory import ( + create_simulation, + list_backends, + register_backend, +) +from strands_robots.simulation.model_registry import ( + list_available_models, + list_registered_urdfs, + register_urdf, + resolve_model, + resolve_urdf, +) +from strands_robots.simulation.models import ( + SimCamera, + SimObject, + SimRobot, + SimStatus, + SimWorld, + TrajectoryStep, +) + +# --- Heavy imports (lazy — loaded when mujoco backend is available) --- +# MuJoCo-specific lazy imports will be added when the mujoco/ subpackage +# is introduced. For now, only the lightweight foundation is available. +_LAZY_IMPORTS: dict[str, tuple[str, str]] = {} + + +__all__ = [ + # ABC + "SimEngine", + # Factory + "create_simulation", + "list_backends", + "register_backend", + # Default backend alias (available when mujoco backend is installed) + # "Simulation", + # "MuJoCoSimulation", + # Shared dataclasses + "SimStatus", + "SimRobot", + "SimObject", + "SimCamera", + "SimWorld", + "TrajectoryStep", + # MuJoCo builder (available when mujoco backend is installed) + # "MJCFBuilder", + # Model registry + "register_urdf", + "resolve_model", + "resolve_urdf", + "list_registered_urdfs", + "list_available_models", +] + + +def __getattr__(name: str) -> Any: + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + module = _importlib.import_module(module_path) + value = getattr(module, attr_name) + globals()[name] = value + return value + raise AttributeError(f"module 'strands_robots.simulation' has no attribute {name!r}") + + +# NOTE: MuJoCo GL backend configuration lives in the top-level +# strands_robots/__init__.py to ensure it runs before any `import mujoco`. +# Do NOT duplicate it here — see PR #86 for the canonical location. diff --git a/strands_robots/simulation/base.py b/strands_robots/simulation/base.py new file mode 100644 index 0000000..7ca2098 --- /dev/null +++ b/strands_robots/simulation/base.py @@ -0,0 +1,222 @@ +"""Simulation ABC — backend-agnostic interface for all simulation engines. + +Every simulation backend (MuJoCo, Isaac, Newton) implements this interface. +Agent tools and the Robot() factory interact through these methods only — +they never touch backend-specific APIs directly. + +Usage:: + + from strands_robots.simulation import Simulation # returns MuJoCo by default + + # Or explicitly: + from strands_robots.simulation.mujoco import MuJoCoSimulation + + # Future: + from strands_robots.simulation.isaac import IsaacSimulation + from strands_robots.simulation.newton import NewtonSimulation +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger(__name__) + + +class SimEngine(ABC): + """Abstract base class for simulation engines. + + Defines the contract that all backends (MuJoCo, Isaac, Newton) must + implement. This is the *programmatic* API — the AgentTool layer + wraps it with tool_spec/stream for LLM access. + + Method categories: + + **Required** (``@abstractmethod``): Core simulation loop — world + lifecycle, entity management, observation/action, rendering. Every + physics engine must implement these to be usable. + + **Optional** (default raises ``NotImplementedError``): Higher-level + features — scene loading, policy running, domain randomization, + contact queries. Backends opt in by overriding only what they support. + + Lifecycle:: + + sim = SomeEngine() + sim.create_world() + sim.add_robot("so100", data_config="so100") + sim.add_object("cube", shape="box", position=[0.3, 0, 0.05]) + + # Control loop + obs = sim.get_observation("so100") + sim.send_action({"joint_0": 0.5}, robot_name="so100") + sim.step(n_steps=10) + + # Render + result = sim.render(camera_name="default") + + # Cleanup + sim.destroy() + """ + + # --- World lifecycle --- + + @abstractmethod + def create_world( + self, + timestep: float | None = None, + gravity: list[float] | None = None, + ground_plane: bool = True, + ) -> dict[str, Any]: + """Create a new simulation world.""" + ... + + @abstractmethod + def destroy(self) -> dict[str, Any]: + """Destroy the simulation world and release resources.""" + ... + + @abstractmethod + def reset(self) -> dict[str, Any]: + """Reset simulation to initial state.""" + ... + + @abstractmethod + def step(self, n_steps: int = 1) -> dict[str, Any]: + """Advance simulation by n physics steps.""" + ... + + @abstractmethod + def get_state(self) -> dict[str, Any]: + """Get full simulation state summary.""" + ... + + # --- Robot management --- + + @abstractmethod + def add_robot( + self, + name: str, + urdf_path: str | None = None, + data_config: str | None = None, + position: list[float] | None = None, + orientation: list[float] | None = None, + ) -> dict[str, Any]: + """Add a robot to the simulation.""" + ... + + @abstractmethod + def remove_robot(self, name: str) -> dict[str, Any]: + """Remove a robot from the simulation.""" + ... + + # --- Object management --- + + @abstractmethod + def add_object( + self, + name: str, + shape: str = "box", + position: list[float] | None = None, + orientation: list[float] | None = None, + size: list[float] | None = None, + color: list[float] | None = None, + mass: float = 0.1, + is_static: bool = False, + mesh_path: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Add an object to the scene.""" + ... + + @abstractmethod + def remove_object(self, name: str) -> dict[str, Any]: + """Remove an object from the scene.""" + ... + + # --- Observation / Action --- + + @abstractmethod + def get_observation(self, robot_name: str | None = None, camera_name: str | None = None) -> dict[str, Any]: + """Get observation from simulation. + + Convenience method that delegates to the underlying Robot + abstraction. Provides a unified interface for agent tools + that interact with simulation without needing to distinguish + between Robot and Sim layers. + """ + ... + + @abstractmethod + def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_substeps: int = 1) -> None: + """Apply action to simulation. + + Convenience method that delegates to the underlying Robot + abstraction. The simulation engine acts as a facade so agent + tools can use ``sim.send_action()`` without knowing about + the Robot/Policy layer. + """ + ... + + # --- Rendering --- + + @abstractmethod + def render( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + """Render a camera view. + + Returns dict with ``"image"`` key (numpy array, RGB uint8) and + optional ``"depth"`` key (float32 depth map). Resolution comes + from camera config unless ``width``/``height`` are given. + """ + ... + + # --- Optional overrides (have default no-op implementations) --- + + def load_scene(self, scene_path: str) -> dict[str, Any]: + """Load a complete scene from file. Override per backend.""" + raise NotImplementedError("load_scene not implemented by this backend") + + def run_policy(self, robot_name: str, policy_provider: str = "mock", **kwargs: Any) -> dict[str, Any]: + """Run a policy loop in the simulation. + + Orchestration shortcut: internally creates a Policy, then loops + ``obs → policy(obs) → send_action(action) → step()``. + Intentionally placed on SimEngine as a facade for agent tools + that need a single ``simulation(action="run_policy")`` interface. + Override per backend. + """ + raise NotImplementedError("run_policy not implemented by this backend") + + def randomize(self, **kwargs: Any) -> dict[str, Any]: + """Apply domain randomization. + + Concrete backends define their own parameter signatures. + Override per backend. + """ + raise NotImplementedError("randomize not implemented by this backend") + + def get_contacts(self) -> dict[str, Any]: + """Get contact information. Override per backend.""" + raise NotImplementedError("get_contacts not implemented by this backend") + + def cleanup(self) -> None: + """Release all resources. Called on __del__ / context exit.""" + pass + + def __enter__(self) -> SimEngine: + return self + + def __exit__(self, *exc: object) -> None: + self.cleanup() + + def __del__(self) -> None: + try: + self.cleanup() + except Exception as e: + # Best-effort cleanup during GC — exceptions can't propagate + # from __del__ (CPython ignores them), so log for visibility. + logger.warning("Cleanup error during __del__: %s", e) diff --git a/strands_robots/simulation/factory.py b/strands_robots/simulation/factory.py new file mode 100644 index 0000000..e7b0a5b --- /dev/null +++ b/strands_robots/simulation/factory.py @@ -0,0 +1,233 @@ +"""Simulation factory — create_simulation() and runtime backend registration. + +Mirrors the policy factory pattern: JSON-driven defaults with runtime +override capability. Backends are lazy-loaded on first use. + +Usage:: + + from strands_robots.simulation import create_simulation + + # Default backend (MuJoCo) + sim = create_simulation() + + # Explicit backend + sim = create_simulation("mujoco", timestep=0.001) + + # Future backends + sim = create_simulation("isaac", gpu_id=0) + sim = create_simulation("newton") + + # Custom backend (runtime-registered) + from strands_robots.simulation.factory import register_backend + register_backend("my_sim", lambda: MySimBackend, aliases=["custom"]) + sim = create_simulation("custom") +""" + +from __future__ import annotations + +import importlib +import logging +from collections.abc import Callable +from typing import Any + +from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + +# ───────────────────────────────────────────────────────────────────── +# Built-in backend registry (lazy loaders — no imports at module load) +# ───────────────────────────────────────────────────────────────────── + +_BUILTIN_BACKENDS: dict[str, tuple[str, str]] = { + "mujoco": ( + "strands_robots.simulation.mujoco.simulation", + "Simulation", + ), + # Future: + # "isaac": ("strands_robots.simulation.isaac.simulation", "IsaacSimulation"), + # "newton": ("strands_robots.simulation.newton.simulation", "NewtonSimulation"), +} + +_BUILTIN_ALIASES: dict[str, str] = { + "mj": "mujoco", + "mjc": "mujoco", + "mjx": "mujoco", + # "isaac_sim": "isaac", + # "isaacsim": "isaac", + # "nvidia": "isaac", +} + +DEFAULT_BACKEND = "mujoco" + +# ───────────────────────────────────────────────────────────────────── +# Runtime registration (for user-defined backends not in built-ins) +# ───────────────────────────────────────────────────────────────────── + +_runtime_registry: dict[str, Callable[[], type[SimEngine]]] = {} +_runtime_aliases: dict[str, str] = {} + + +def register_backend( + name: str, + loader: Callable[[], type[SimEngine]], + aliases: list[str] | None = None, + force: bool = False, +) -> None: + """Register a custom simulation backend at runtime. + + Use this to add backends without editing source code. + + Args: + name: Backend identifier (e.g., ``"my_physics"``). + loader: Zero-arg callable that returns the backend **class** + (not instance). Called lazily on first ``create_simulation()``. + aliases: Optional short names that resolve to ``name``. + force: If False (default), raises ValueError when ``name`` or + an alias is already registered. Set True to overwrite. + + Raises: + ValueError: If ``name`` or an alias conflicts with an existing + registration and ``force`` is False. + + Example:: + + from strands_robots.simulation.factory import register_backend + + register_backend( + "bullet", + lambda: BulletSimulation, + aliases=["pybullet", "pb"], + ) + sim = create_simulation("bullet") + """ + if not force: + # Check name against ALL existing identifiers (backends + aliases) + if name in _runtime_registry or name in _BUILTIN_BACKENDS: + raise ValueError(f"Backend {name!r} already registered. Use force=True to overwrite.") + if name in _BUILTIN_ALIASES: + raise ValueError( + f"Name {name!r} conflicts with built-in alias (resolves to {_BUILTIN_ALIASES[name]!r}). Use force=True to overwrite." + ) + if name in _runtime_aliases: + raise ValueError( + f"Name {name!r} conflicts with runtime alias (resolves to {_runtime_aliases[name]!r}). Use force=True to overwrite." + ) + if aliases: + for alias in aliases: + if alias in _BUILTIN_BACKENDS or alias in _runtime_registry: + raise ValueError( + f"Alias {alias!r} conflicts with existing backend name. Use force=True to overwrite." + ) + if alias in _BUILTIN_ALIASES: + raise ValueError(f"Alias {alias!r} conflicts with built-in alias. Use force=True to overwrite.") + if alias in _runtime_aliases: + raise ValueError(f"Alias {alias!r} already registered. Use force=True to overwrite.") + + _runtime_registry[name] = loader + if aliases: + for alias in aliases: + _runtime_aliases[alias] = name + logger.debug("Registered simulation backend: %s (aliases=%s)", name, aliases) + + +def list_backends() -> list[str]: + """List all available backend names (built-in + runtime-registered). + + Returns: + Sorted list of unique backend identifiers and aliases. + + Example:: + + >>> list_backends() + ['mj', 'mjc', 'mjx', 'mujoco'] + """ + names: set[str] = set() + names.update(_BUILTIN_BACKENDS.keys()) + names.update(_BUILTIN_ALIASES.keys()) + names.update(_runtime_registry.keys()) + names.update(_runtime_aliases.keys()) + return sorted(names) + + +def _resolve_name(backend: str) -> str: + """Resolve aliases to canonical backend name.""" + # Runtime aliases first (user overrides win) + if backend in _runtime_aliases: + return _runtime_aliases[backend] + # Built-in aliases + if backend in _BUILTIN_ALIASES: + return _BUILTIN_ALIASES[backend] + return backend + + +def _import_backend_class(name: str) -> type[SimEngine]: + """Import and return a backend class by canonical name.""" + # 1. Runtime registry (user-registered) + if name in _runtime_registry: + cls: type[SimEngine] = _runtime_registry[name]() + logger.debug("Loaded runtime backend: %s → %s", name, cls.__name__) + return cls + + # 2. Built-in registry + if name in _BUILTIN_BACKENDS: + module_path, class_name = _BUILTIN_BACKENDS[name] + try: + module = importlib.import_module(module_path) + except ModuleNotFoundError as exc: + raise ImportError( + f"Simulation backend {name!r} is declared in the built-in registry " + f"but its implementation module {module_path!r} is not available. " + f"This usually means the backend has not been installed yet " + f"(e.g. `pip install strands-robots[{name}]`) or the backend " + f"implementation has not landed in this release. " + f"Register a custom backend via " + f"`strands_robots.simulation.factory.register_backend()` to proceed." + ) from exc + backend_cls: type[SimEngine] = getattr(module, class_name) # type: ignore[assignment] + logger.debug("Loaded built-in backend: %s → %s.%s", name, module_path, class_name) + return backend_cls + + raise ValueError(f"Unknown simulation backend: {name!r}. Available: {', '.join(list_backends())}") + + +def create_simulation( + backend: str = DEFAULT_BACKEND, + **kwargs: Any, +) -> SimEngine: + """Create a simulation backend instance. + + This is the primary entry point for creating simulations. + Backend classes are lazy-loaded on first call. + + Args: + backend: Backend name or alias. Defaults to ``"mujoco"``. + Built-in: ``"mujoco"`` (aliases: ``"mj"``, ``"mjc"``, ``"mjx"``). + **kwargs: Backend-specific keyword arguments passed to the + constructor (e.g., ``tool_name``, ``timestep``). + + Returns: + A ``SimEngine`` instance ready for ``create_world()``. + + Raises: + ValueError: If the backend name is not recognized. + ImportError: If the backend's dependencies are missing + (e.g., ``pip install mujoco``). + + Examples:: + + # Default (MuJoCo) + sim = create_simulation() + sim.create_world() + sim.add_robot("so100") + + # With alias + sim = create_simulation("mj") + + # Pass kwargs to backend constructor + sim = create_simulation("mujoco", tool_name="my_sim") + """ + canonical = _resolve_name(backend) + logger.info("Creating simulation: %s (resolved from %r)", canonical, backend) + + BackendClass = _import_backend_class(canonical) + return BackendClass(**kwargs) diff --git a/strands_robots/simulation/model_registry.py b/strands_robots/simulation/model_registry.py new file mode 100644 index 0000000..b7af5e9 --- /dev/null +++ b/strands_robots/simulation/model_registry.py @@ -0,0 +1,142 @@ +"""Robot model resolution — URDF registry + asset manager. + +Bridges the robot registry with actual URDF/MJCF files on disk. + +Resolution order for :func:`resolve_model`: + 1. User-registered URDFs (:func:`register_urdf`) + 2. URDF search paths (``STRANDS_ASSETS_DIR``, CWD, etc.) + 3. Asset manager (``robot_descriptions`` — fallback for standard robots) +""" + +from __future__ import annotations + +import logging +import os + +from strands_robots.utils import get_search_paths + +logger = logging.getLogger(__name__) + +# URDF search paths are resolved lazily via :func:`strands_robots.utils.get_search_paths` +# at every lookup — this avoids snapshotting ``Path.cwd()`` and ``STRANDS_ASSETS_DIR`` +# at import time, which caused silent wrong-path bugs when tests/notebooks chdir after +# import. + +try: + from strands_robots.assets import ( + format_robot_table, + resolve_model_path, + ) + + _HAS_ASSET_MANAGER = True +except ImportError: + _HAS_ASSET_MANAGER = False + +try: + from strands_robots.registry import get_robot, resolve_name + + _HAS_REGISTRY = True +except ImportError: + _HAS_REGISTRY = False + +# Logged lazily on first resolution via _log_configuration_once() — +# avoids noisy INFO on every ``import strands_robots``. +_CONFIG_LOGGED = False + + +def _log_configuration_once() -> None: + global _CONFIG_LOGGED + if _CONFIG_LOGGED: + return + logger.debug("Asset manager available: %s", _HAS_ASSET_MANAGER) + _CONFIG_LOGGED = True + + +# Runtime cache for user-registered URDFs +_URDF_REGISTRY: dict[str, str] = {} + + +def register_urdf(data_config: str, urdf_path: str) -> None: + """Register a URDF/MJCF file for a data_config name.""" + _URDF_REGISTRY[data_config] = urdf_path + logger.info("📋 Registered model for '%s': %s", data_config, urdf_path) + + +def resolve_model(name: str, prefer_scene: bool = True) -> str | None: + """Resolve a robot name or data_config to an MJCF/URDF model path. + + Resolution order (local assets take priority): + 1. User-registered URDFs (custom user registrations) + 2. URDF search paths (STRANDS_ASSETS_DIR, CWD, etc.) + 3. Asset manager (robot_descriptions — fallback for standard robots) + """ + _log_configuration_once() + # 1+2. Check local/custom paths first (user overrides win) + local = resolve_urdf(name) + if local: + return local + + # 3. Fall back to asset manager + if _HAS_ASSET_MANAGER: + path = resolve_model_path(name, prefer_scene=prefer_scene) + if path and path.exists(): + return str(path) + if prefer_scene: + path = resolve_model_path(name, prefer_scene=False) + if path and path.exists(): + return str(path) + + return None + + +def resolve_urdf(data_config: str) -> str | None: + """Resolve a data_config name to a URDF file path. + + Also checks the registry's ``legacy_urdf`` field — a backward-compatible + path for robots that were registered before the MJCF asset system + was introduced (e.g. robots originally configured with raw URDF paths). + """ + if data_config in _URDF_REGISTRY: + urdf_rel = _URDF_REGISTRY[data_config] + if os.path.isabs(urdf_rel) and os.path.exists(urdf_rel): + return str(urdf_rel) + for search_dir in get_search_paths(): + candidate = search_dir / urdf_rel + if candidate.exists(): + return str(candidate) + + if _HAS_REGISTRY: + canonical = resolve_name(data_config) + info = get_robot(canonical) + # ``legacy_urdf``: backward-compatible URDF path from before the + # MJCF asset system was introduced. Kept so that existing + # user configs referencing raw URDF paths continue to work. + if info and "legacy_urdf" in info: + urdf_rel = info["legacy_urdf"] + if os.path.isabs(urdf_rel) and os.path.exists(urdf_rel): + return str(urdf_rel) + for search_dir in get_search_paths(): + candidate = search_dir / urdf_rel + if candidate.exists(): + return str(candidate) + + logger.debug("URDF not found for '%s' in search paths", data_config) + return None + + +def list_registered_urdfs() -> dict[str, str | None]: + """List all registered URDF mappings and their resolved paths.""" + return {config_name: resolve_urdf(config_name) for config_name in _URDF_REGISTRY} + + +def list_available_models() -> str: + """List all available robot models (Menagerie + custom).""" + if _HAS_ASSET_MANAGER: + return str(format_robot_table()) + + lines = ["Registered URDFs:"] + for name, path in _URDF_REGISTRY.items(): + resolved = resolve_urdf(name) + status = "✅" if resolved else "❌" + lines.append(f" {status} {name}: {path}") + return "\n".join(lines) diff --git a/strands_robots/simulation/models.py b/strands_robots/simulation/models.py new file mode 100644 index 0000000..e339d15 --- /dev/null +++ b/strands_robots/simulation/models.py @@ -0,0 +1,143 @@ +"""Dataclasses for simulation state. + +These dataclasses provide a backend-independent typed state representation +consumed by simulation engine implementations (e.g. MuJoCo, Isaac Sim, +PyBullet). + +They enable: + - Type-safe state tracking across simulation steps. + - Serialisation for checkpoints and trajectory recording. + - A backend-independent interface for agent tools. + +They are defined alongside the ``SimEngine`` ABC because its method +signatures reference them (e.g. ``create_world() → SimWorld``). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class SimStatus(Enum): + """Simulation execution status.""" + + IDLE = "idle" + RUNNING = "running" + PAUSED = "paused" + COMPLETED = "completed" + ERROR = "error" + + +@dataclass +class SimRobot: + """A robot instance within the simulation.""" + + name: str + urdf_path: str + position: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) + orientation: list[float] = field(default_factory=lambda: [1.0, 0.0, 0.0, 0.0]) # wxyz quat + data_config: str | None = None + body_id: int = -1 + joint_ids: list[int] = field(default_factory=list) + joint_names: list[str] = field(default_factory=list) + actuator_ids: list[int] = field(default_factory=list) + namespace: str = "" + policy_running: bool = False + policy_steps: int = 0 + policy_instruction: str = "" + + +@dataclass +class SimObject: + """An object in the simulation scene.""" + + name: str + shape: str # "box", "sphere", "cylinder", "capsule", "mesh" + position: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) + orientation: list[float] = field(default_factory=lambda: [1.0, 0.0, 0.0, 0.0]) + size: list[float] = field(default_factory=lambda: [0.05, 0.05, 0.05]) + color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 1.0]) # RGBA + mass: float = 0.1 + mesh_path: str | None = None + body_id: int = -1 + is_static: bool = False + _original_position: list[float] = field(default_factory=list) + _original_color: list[float] = field(default_factory=list) + + def __post_init__(self) -> None: + self._original_position = list(self.position) + self._original_color = list(self.color) + + +@dataclass +class SimCamera: + """A camera in the simulation.""" + + name: str + position: list[float] = field(default_factory=lambda: [1.0, 1.0, 1.0]) + target: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) + fov: float = 60.0 + width: int = 640 + height: int = 480 + camera_id: int = -1 + + +@dataclass +class TrajectoryStep: + """A single step in a recorded trajectory.""" + + timestamp: float + sim_time: float + robot_name: str + observation: dict[str, Any] + action: dict[str, Any] + instruction: str = "" + + +@dataclass +class SimWorld: + """Complete simulation world state. + + Backend-independent state with engine-specific internals kept in three + escape hatches, each with a distinct role so backend implementers know + which to use: + + * ``_model``: the physics engine's **core model handle** — the single + compiled/loaded representation of the scene (e.g. ``mujoco.MjModel``, + Isaac's ``Scene``, PyBullet's body registry). Every backend has one. + * ``_data``: the physics engine's **core simulation state handle** — + the mutable per-step state companion to ``_model`` + (e.g. ``mujoco.MjData``, Isaac's ``World``). Every backend has one. + * ``_backend_state``: a **catch-all dict** for everything else the + backend needs to persist — generated XML, temp dirs, recording + buffers, caches, etc. Prefer this over adding new fields here. + + All three are typed ``Any``/``dict`` so nothing leaks engine-specific + types into this base module. + """ + + robots: dict[str, SimRobot] = field(default_factory=dict) + objects: dict[str, SimObject] = field(default_factory=dict) + cameras: dict[str, SimCamera] = field(default_factory=dict) + timestep: float = 0.002 # 500Hz physics + gravity: list[float] = field(default_factory=lambda: [0.0, 0.0, -9.81]) + ground_plane: bool = True + status: SimStatus = SimStatus.IDLE + sim_time: float = 0.0 + step_count: int = 0 + # Engine core handles — set after the backend builds the world. + # Use these for the primary model/state objects only; put everything + # else in ``_backend_state`` below. + _model: Any = None # Engine-specific model handle (e.g. MjModel, Scene) + _data: Any = None # Engine-specific data handle (e.g. MjData, World) + # Catch-all for backend-specific state that isn't the core model/data. + # Examples: generated XML strings, temp dirs, recording buffers + # (``_recording``, ``_trajectory``, ``_dataset_recorder``), caches, etc. + # Prefer this over adding new fields to ``SimWorld``. + _backend_state: dict[str, Any] = field(default_factory=dict) + # Physics state checkpoints (used by save_state/restore_state in PR #85). + # Kept as a top-level field — requested by @yinsong1986 during review to + # avoid monkey-patching when ``reset()`` creates a fresh ``SimWorld``. + _checkpoints: dict[str, Any] = field(default_factory=dict) diff --git a/strands_robots/tools/__init__.py b/strands_robots/tools/__init__.py index c18ccbb..7ae62c0 100644 --- a/strands_robots/tools/__init__.py +++ b/strands_robots/tools/__init__.py @@ -12,6 +12,7 @@ import importlib as _importlib _LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "download_assets": (".download_assets", "download_assets"), "gr00t_inference": (".gr00t_inference", "gr00t_inference"), "lerobot_calibrate": (".lerobot_calibrate", "lerobot_calibrate"), "lerobot_camera": (".lerobot_camera", "lerobot_camera"), diff --git a/strands_robots/tools/download_assets.py b/strands_robots/tools/download_assets.py new file mode 100644 index 0000000..2f59adf --- /dev/null +++ b/strands_robots/tools/download_assets.py @@ -0,0 +1,84 @@ +"""Download robot model assets — Strands Agent ``@tool`` wrapper. + +Thin wrapper around :mod:`strands_robots.assets.download` that exposes +``download_robots()`` as an agent tool. All download logic lives in the +``assets.download`` module; this file only handles input parsing and +output formatting for the Strands Agent SDK. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from strands.tools.decorator import tool + +from strands_robots.assets.download import download_robots, get_user_assets_dir +from strands_robots.assets.manager import list_available_robots +from strands_robots.registry import format_robot_table + +logger = logging.getLogger(__name__) + + +@tool +def download_assets( + action: str = "download", + robots: str | None = None, + category: str | None = None, + force: bool = False, +) -> dict[str, Any]: + """Download and manage robot model assets (MJCF XML + meshes). + + Assets are sourced from ``robot_descriptions`` (recommended by MuJoCo + Menagerie, requires ``pip install strands-robots[sim-mujoco]``). When + ``robot_descriptions`` is unavailable, falls back to a shallow + ``git clone`` of the Menagerie repo. Robots with a custom GitHub + source in the registry are cloned from their respective repos. + + Downloaded assets are cached in ``~/.strands_robots/assets/`` + (override with ``STRANDS_ASSETS_DIR``). + + Args: + action: ``download`` | ``list`` | ``status`` + robots: Comma-separated names (e.g. ``so100,panda``). Omit for all. + category: Filter: arm, bimanual, hand, humanoid, mobile, mobile_manip + force: Re-download even if present + """ + try: + if action == "list": + return { + "status": "success", + "content": [{"text": f"🤖 Available Robots:\n\n{format_robot_table()}"}], + } + + if action == "status": + robots_info = list_available_robots() + available = sum(1 for r in robots_info if r["available"]) + lines = [f"📊 {available} available, {len(robots_info) - available} missing"] + lines.extend( + f" {'✅' if r['available'] else '❌'} {r['name']:<20s} {r['category']:<12s} {r['description']}" + for r in robots_info + ) + lines.append(f"\n📁 Cache: {get_user_assets_dir()}") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + if action == "download": + robot_names = [r.strip() for r in robots.split(",") if r.strip()] if robots else None + result = download_robots(names=robot_names, category=category, force=force) + parts = [ + f"📦 Downloaded: {result['downloaded']}, Skipped: {result['skipped']}, Failed: {result['failed']}", + f"Method: {result.get('method', '?')}", + ] + if result.get("failed_details"): + parts.extend(f" ❌ {n}: {r}" for n, r in result["failed_details"].items()) + parts.append(f"📁 Assets: {result.get('assets_dir', '?')}") + return {"status": "success", "content": [{"text": "\n".join(parts)}]} + + return { + "status": "error", + "content": [{"text": f"Unknown action: {action}. Valid: download, list, status"}], + } + + except Exception as exc: + logger.error("download_assets error: %s", exc) + return {"status": "error", "content": [{"text": f"❌ Error: {exc}"}]} diff --git a/strands_robots/utils.py b/strands_robots/utils.py index f2a930c..b56f3f5 100644 --- a/strands_robots/utils.py +++ b/strands_robots/utils.py @@ -2,6 +2,8 @@ import importlib import logging +import os +from pathlib import Path logger = logging.getLogger(__name__) @@ -49,3 +51,133 @@ def require_optional( parts.append(f" pip install 'strands-robots[{extra}]'") parts.append(f" pip install {install_hint}") raise ImportError("\n".join(parts)) from None + + +# ───────────────────────────────────────────────────────────────────── +# Path resolution — single source of truth for all strands-robots paths +# ───────────────────────────────────────────────────────────────────── + +#: Default base directory for all user data. +DEFAULT_BASE_DIR = Path.home() / ".strands_robots" + + +def get_base_dir() -> Path: + """Get the base directory for strands-robots user data. + + Resolution (in priority order): + + 1. ``STRANDS_BASE_DIR`` env var — explicit override. Use this when + you want to relocate *all* strands-robots user data (assets, + user registry, caches) to a non-default location. + 2. ``~/.strands_robots/`` — default. + + Note: + ``STRANDS_ASSETS_DIR`` **only** controls the assets subdirectory + (see :func:`get_assets_dir`). It does *not* move the base dir, + so user-level metadata like ``user_robots.json`` always lands in + a predictable location rather than wherever the assets happen + to be pointed. + + Returns: + Path to the base directory (created if needed). + """ + custom = os.getenv("STRANDS_BASE_DIR") + d = Path(custom) if custom else DEFAULT_BASE_DIR + d.mkdir(parents=True, exist_ok=True) + return d + + +def get_assets_dir() -> Path: + """Get the assets directory (robot model files, meshes, URDFs). + + Resolution: + 1. ``STRANDS_ASSETS_DIR`` env var — used as-is + 2. ``~/.strands_robots/assets/`` — default + + Returns: + Path to the assets directory (created if needed). + """ + custom = os.getenv("STRANDS_ASSETS_DIR") + if custom: + d = Path(custom) + else: + d = DEFAULT_BASE_DIR / "assets" + d.mkdir(parents=True, exist_ok=True) + return d + + +def resolve_asset_path(relative_or_absolute: str | Path | None, default_name: str = "") -> Path: + """Resolve an asset path against the assets directory. + + Args: + relative_or_absolute: Path to resolve. + - ``None`` → ``//`` + - Absolute (or ``~/...``) → expanded as-is + - Relative → ``//`` + default_name: Fallback subdirectory name when path is None. + + Returns: + Resolved absolute Path. + """ + assets = get_assets_dir() + if relative_or_absolute is None: + return assets / default_name + expanded = Path(relative_or_absolute).expanduser() + if expanded.is_absolute(): + return expanded + return assets / expanded + + +# ───────────────────────────────────────────────────────────────────── +# Path safety — prevent traversal via untrusted components +# ───────────────────────────────────────────────────────────────────── + + +def safe_join(base: Path, untrusted: str) -> Path: + """Join *base* with an untrusted relative path, rejecting traversal. + + Used to protect against ``../`` escapes in registry-sourced or + user-supplied path components before they reach the filesystem. + + Args: + base: Trusted base directory. + untrusted: Relative path component (may contain ``/`` but must not + escape *base*). + + Returns: + Normalised absolute Path under *base*. + + Raises: + ValueError: If the resulting path would escape *base*. + + Example:: + + safe_join(Path("/assets"), "robot/model.xml") # OK + safe_join(Path("/assets"), "../etc/passwd") # ValueError + """ + joined = Path(os.path.normpath(base / untrusted)) + base_norm = Path(os.path.normpath(base)) + if not (joined == base_norm or str(joined).startswith(str(base_norm) + os.sep)): + raise ValueError(f"Path traversal blocked: {untrusted!r} escapes {base}") + return joined + + +def get_search_paths() -> list[Path]: + """Get ordered list of asset search paths. + + Used by both :mod:`strands_robots.assets.manager` and + :mod:`strands_robots.assets.download` — centralised here to avoid + a circular dependency between those two modules. + + Order (local assets take priority over defaults): + 1. User asset dir (``STRANDS_ASSETS_DIR`` or ``~/.strands_robots/assets/``) + 2. ``CWD/assets`` (project-local) + """ + paths: list[Path] = [] + user_cache = get_assets_dir() + if user_cache not in paths: + paths.append(user_cache) + cwd_assets = Path.cwd() / "assets" + if cwd_assets not in paths: + paths.append(cwd_assets) + return paths diff --git a/tests/test_registry_integrity.py b/tests/test_registry_integrity.py new file mode 100644 index 0000000..7667631 --- /dev/null +++ b/tests/test_registry_integrity.py @@ -0,0 +1,147 @@ +"""Registry integrity tests — catch silent regressions in robots.json. + +These tests enforce invariants on the robot registry that prevent classes +of bugs like the one flagged by @awsarron on PR #84 review (2026-04-21): +entries where ``robot_descriptions_module`` was accidentally dropped during +the 38→68 robot expansion, silently breaking auto-download. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +REGISTRY_PATH = Path(__file__).parent.parent / "strands_robots" / "registry" / "robots.json" + + +@pytest.fixture(scope="module") +def registry() -> dict: + """Load the robot registry once per module.""" + with open(REGISTRY_PATH) as f: + data = json.load(f) + return data.get("robots", data) + + +def test_registry_loads(registry: dict) -> None: + """Registry file parses as valid JSON with robot entries.""" + assert len(registry) > 0 + + +def test_every_robot_declares_auto_download_strategy(registry: dict) -> None: + """Every robot with an ``asset`` block must declare HOW it gets auto-downloaded. + + Valid options (exactly one required): + 1. ``asset.robot_descriptions_module`` — the robot_descriptions pip module name. + 2. ``asset.source`` with ``type: "github"`` — custom GitHub source block. + 3. ``asset.auto_download: false`` — explicit opt-out (user must supply assets). + + Without one of these, auto-download silently falls through to the + naming-convention heuristic, which fails for most robots and only + logs a warning. This was the trossen_wxai + google_robot regression. + """ + offenders = [] + for name, info in registry.items(): + asset = info.get("asset") + if not asset: + continue # No asset block — nothing to auto-download. + + has_rd = "robot_descriptions_module" in asset + has_source = isinstance(asset.get("source"), dict) and asset["source"].get("type") == "github" + opts_out = asset.get("auto_download") is False + + if not (has_rd or has_source or opts_out): + offenders.append(name) + + assert not offenders, ( + "Robots missing auto-download strategy (add `robot_descriptions_module`, " + "`source: {type: github, ...}`, or `auto_download: false`): " + ", ".join(offenders) + ) + + +def test_asset_dirs_are_unique(registry: dict) -> None: + """No two robots should share the same asset directory name.""" + dir_counts: dict[str, list[str]] = {} + for name, info in registry.items(): + asset_dir = info.get("asset", {}).get("dir") + if asset_dir: + dir_counts.setdefault(asset_dir, []).append(name) + + duplicates = {d: names for d, names in dir_counts.items() if len(names) > 1} + assert not duplicates, f"Duplicate asset dirs: {duplicates}" + + +def test_no_path_traversal_in_asset_paths(registry: dict) -> None: + """Registry-sourced paths must not contain ``..`` (path-traversal defense in depth).""" + for name, info in registry.items(): + asset = info.get("asset", {}) + for key in ("dir", "model_xml", "scene_xml"): + value = asset.get(key, "") + assert ".." not in str(value).split("/"), f"{name}.asset.{key} contains '..': {value!r}" + + +def test_auto_download_false_is_bool_not_string(registry: dict) -> None: + """``auto_download`` must be a proper JSON boolean, not the string ``"false"``.""" + for name, info in registry.items(): + ad = info.get("asset", {}).get("auto_download") + if ad is not None: + assert isinstance(ad, bool), f"{name}.asset.auto_download must be bool, got {type(ad).__name__}: {ad!r}" + + +def _all_canonical_names(registry: dict) -> set[str]: + return set(registry.keys()) + + +def _collect_aliases(registry: dict) -> dict[str, str]: + """Return mapping of alias → owning robot name.""" + out: dict[str, str] = {} + for name, info in registry.items(): + for alias in info.get("aliases", []) or []: + out.setdefault(alias, name) + return out + + +def test_aliases_unique_across_registry(registry: dict) -> None: + """No two robots may declare the same alias — last-loaded would silently win.""" + seen: dict[str, str] = {} + collisions: list[str] = [] + for name, info in registry.items(): + for alias in info.get("aliases", []) or []: + if alias in seen and seen[alias] != name: + collisions.append(f"{alias!r} used by {seen[alias]} AND {name}") + seen[alias] = name + assert not collisions, "Alias collisions:\n " + "\n ".join(collisions) + + +def test_no_alias_shadows_canonical_name(registry: dict) -> None: + """An alias must not equal the canonical name of another robot. + + Shadowing causes resolution order to silently determine the winner, which + is fragile — a future reorder of robots.json could flip which robot a + name resolves to. + """ + canonical = _all_canonical_names(registry) + shadows: list[str] = [] + for name, info in registry.items(): + for alias in info.get("aliases", []) or []: + if alias in canonical and alias != name: + shadows.append(f"{name}.aliases contains {alias!r} which is a canonical robot name") + assert not shadows, "Alias shadows canonical:\n " + "\n ".join(shadows) + + +def test_hardware_only_robots_declare_lerobot_type(registry: dict) -> None: + """Robots without an ``asset`` block must still declare a LeRobot hardware type. + + Prevents silent typos in ``hardware.lerobot_type`` — catches a misspelled + type during registry expansion rather than at teleop time. + """ + offenders: list[str] = [] + for name, info in registry.items(): + if "asset" in info: + continue + hw = info.get("hardware") or {} + lerobot_type = hw.get("lerobot_type") + if not isinstance(lerobot_type, str) or not lerobot_type.strip(): + offenders.append(name) + assert not offenders, "Hardware-only robots missing 'hardware.lerobot_type': " + ", ".join(offenders) diff --git a/tests/test_simulation_factory.py b/tests/test_simulation_factory.py new file mode 100644 index 0000000..f8b8cd6 --- /dev/null +++ b/tests/test_simulation_factory.py @@ -0,0 +1,140 @@ +"""Tests for strands_robots.simulation.factory. + +Regression tests for the built-in-backend-missing case and runtime +registration contracts. +""" + +from __future__ import annotations + +import pytest + +from strands_robots.simulation import base as _base +from strands_robots.simulation import factory as _factory + + +@pytest.fixture(autouse=True) +def _clean_runtime_registry(): + """Snapshot + restore runtime registry so tests don't leak state.""" + saved_reg = dict(_factory._runtime_registry) + saved_al = dict(_factory._runtime_aliases) + yield + _factory._runtime_registry.clear() + _factory._runtime_registry.update(saved_reg) + _factory._runtime_aliases.clear() + _factory._runtime_aliases.update(saved_al) + + +def test_default_backend_missing_raises_import_error_with_guidance() -> None: + """When the built-in ``mujoco`` backend module is not installed, we must + raise :class:`ImportError` with an actionable message — **not** a cryptic + ``ModuleNotFoundError`` from deep inside importlib. + """ + # Remove any cached module so we reliably hit the import path. + import sys + + sys.modules.pop("strands_robots.simulation.mujoco", None) + sys.modules.pop("strands_robots.simulation.mujoco.simulation", None) + + with pytest.raises(ImportError) as exc: + _factory.create_simulation() + + msg = str(exc.value) + assert "mujoco" in msg.lower() + assert "register_backend" in msg or "install" in msg.lower() + + +def test_register_backend_loader_must_be_callable() -> None: + """``register_backend`` requires a *loader* (zero-arg callable returning a + class), not the class itself — passing the class directly currently works + only because ``FakeBackend()`` happens to construct an instance. This + test pins the contract so future refactors can't regress into accepting + both and silently doing the wrong thing. + """ + + class FakeBackend(_base.SimEngine): + def create_world(self, **kw): # type: ignore[override] + return {} + + def destroy(self): # type: ignore[override] + return {} + + def reset(self): # type: ignore[override] + return {} + + def step(self, n_steps: int = 1): # type: ignore[override] + return {} + + def get_state(self): # type: ignore[override] + return {} + + def add_robot(self, name, **kw): # type: ignore[override] + return {} + + def remove_robot(self, name): # type: ignore[override] + return {} + + def add_object(self, name, **kw): # type: ignore[override] + return {} + + def remove_object(self, name): # type: ignore[override] + return {} + + def get_observation(self, robot_name=None, camera_name=None): # type: ignore[override] + return {} + + def send_action(self, action, robot_name=None, n_substeps=1): # type: ignore[override] + return None + + def render(self, camera_name="default", width=None, height=None): # type: ignore[override] + return {} + + # Correct usage — loader returns the class + _factory.register_backend("fake_sim", lambda: FakeBackend) + sim = _factory.create_simulation("fake_sim") + assert isinstance(sim, FakeBackend) + + +def test_register_backend_rejects_duplicate_without_force() -> None: + _factory.register_backend("dup_sim", lambda: _FakeMinimal) + with pytest.raises(ValueError): + _factory.register_backend("dup_sim", lambda: _FakeMinimal) + + +class _FakeMinimal(_base.SimEngine): + """Minimal concrete backend used across assertion fixtures.""" + + def create_world(self, **kw): # type: ignore[override] + return {} + + def destroy(self): # type: ignore[override] + return {} + + def reset(self): # type: ignore[override] + return {} + + def step(self, n_steps: int = 1): # type: ignore[override] + return {} + + def get_state(self): # type: ignore[override] + return {} + + def add_robot(self, name, **kw): # type: ignore[override] + return {} + + def remove_robot(self, name): # type: ignore[override] + return {} + + def add_object(self, name, **kw): # type: ignore[override] + return {} + + def remove_object(self, name): # type: ignore[override] + return {} + + def get_observation(self, robot_name=None, camera_name=None): # type: ignore[override] + return {} + + def send_action(self, action, robot_name=None, n_substeps=1): # type: ignore[override] + return None + + def render(self, camera_name="default", width=None, height=None): # type: ignore[override] + return {} diff --git a/tests/test_simulation_foundation.py b/tests/test_simulation_foundation.py new file mode 100644 index 0000000..e3fdb1c --- /dev/null +++ b/tests/test_simulation_foundation.py @@ -0,0 +1,290 @@ +"""Tests for simulation foundation — models, ABC, factory, model_registry. + +These tests verify the lightweight simulation abstractions without +requiring MuJoCo or any heavy dependencies. +""" + +from typing import Any + +import pytest + +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.factory import ( + create_simulation, + list_backends, + register_backend, +) +from strands_robots.simulation.models import ( + SimObject, + SimRobot, + SimStatus, + SimWorld, + TrajectoryStep, +) + +# ── Shared fixtures ────────────────────────────────────────────── + + +def _make_dummy_engine_class() -> type[SimEngine]: + """Create a minimal concrete SimEngine subclass. + + All 12 required abstract methods return empty dicts / None. + Factored out to avoid ~150 lines of repetition across tests. + """ + + class Dummy(SimEngine): + def create_world( + self, + timestep: float | None = None, + gravity: list[float] | None = None, + ground_plane: bool = True, + ) -> dict[str, Any]: + return {} + + def destroy(self) -> dict[str, Any]: + return {} + + def reset(self) -> dict[str, Any]: + return {} + + def step(self, n_steps: int = 1) -> dict[str, Any]: + return {} + + def get_state(self) -> dict[str, Any]: + return {} + + def add_robot( + self, + name: str, + urdf_path: str | None = None, + data_config: str | None = None, + position: list[float] | None = None, + orientation: list[float] | None = None, + ) -> dict[str, Any]: + return {} + + def remove_robot(self, name: str) -> dict[str, Any]: + return {} + + def add_object( + self, + name: str, + shape: str = "box", + position: list[float] | None = None, + orientation: list[float] | None = None, + size: list[float] | None = None, + color: list[float] | None = None, + mass: float = 0.1, + is_static: bool = False, + mesh_path: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + return {} + + def remove_object(self, name: str) -> dict[str, Any]: + return {} + + def get_observation(self, robot_name: str | None = None, camera_name: str | None = None) -> dict[str, Any]: + return {} + + def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_substeps: int = 1) -> None: + return None + + def render( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + return {} + + return Dummy + + +@pytest.fixture +def dummy_engine_class() -> type[SimEngine]: + """Fixture providing a minimal concrete SimEngine subclass.""" + return _make_dummy_engine_class() + + +# ── ABC Tests ──────────────────────────────────────────────────── + + +class TestSimEngine: + """Test the abstract base class contract.""" + + def test_cannot_instantiate_abc(self): + with pytest.raises(TypeError): + SimEngine() + + def test_has_required_abstract_methods(self): + abstract_methods = SimEngine.__abstractmethods__ + expected = { + "create_world", + "destroy", + "reset", + "step", + "get_state", + "add_robot", + "remove_robot", + "add_object", + "remove_object", + "get_observation", + "send_action", + "render", + } + assert expected == abstract_methods + + def test_optional_methods_raise_not_implemented(self, dummy_engine_class): + """Optional methods on a concrete subclass raise NotImplementedError.""" + d = dummy_engine_class() + with pytest.raises(NotImplementedError): + d.load_scene("x") + with pytest.raises(NotImplementedError): + d.run_policy("x") + with pytest.raises(NotImplementedError): + d.randomize() + with pytest.raises(NotImplementedError): + d.get_contacts() + + def test_context_manager_calls_cleanup(self, dummy_engine_class): + """ABC supports context manager protocol and calls cleanup on exit.""" + cleaned = {"flag": False} + + class Cleanable(dummy_engine_class): # type: ignore[misc,valid-type] + def cleanup(self) -> None: + cleaned["flag"] = True + + with Cleanable(): + pass + assert cleaned["flag"] is True + + +# ── Factory Tests ──────────────────────────────────────────────── + + +class TestSimulationFactory: + """Test backend registration and creation — full round-trip.""" + + def test_list_backends_includes_mujoco(self): + backends = list_backends() + assert "mujoco" in backends + + def test_register_create_and_use_backend(self, dummy_engine_class): + """Register a custom backend, create it via factory, verify instance.""" + register_backend("fake_test", lambda: dummy_engine_class, force=True) + assert "fake_test" in list_backends() + sim = create_simulation("fake_test") + assert isinstance(sim, dummy_engine_class) + + def test_register_rejects_duplicate(self, dummy_engine_class): + """Registering an existing name without force raises ValueError.""" + register_backend("dup_test", lambda: dummy_engine_class, force=True) + with pytest.raises(ValueError, match="already registered"): + register_backend("dup_test", lambda: dummy_engine_class) + + def test_register_rejects_builtin_alias_in_aliases(self, dummy_engine_class): + """Cannot use a built-in alias as a new backend's alias.""" + with pytest.raises(ValueError, match="conflicts with built-in"): + register_backend("custom_phys", lambda: dummy_engine_class, aliases=["mj"]) + + # ── Regression tests for alias-shadowing bug (PR #84 review) ── + + def test_register_rejects_builtin_alias_as_name(self, dummy_engine_class): + """Cannot register a new backend under a built-in alias name. + + Regression test for the bug where ``register_backend("mj", loader)`` + succeeded without ``force=True`` because the conflict check only + looked at ``_BUILTIN_BACKENDS`` and ``_runtime_registry``, missing + ``_BUILTIN_ALIASES``. + """ + for builtin_alias in ("mj", "mjc", "mjx"): + with pytest.raises(ValueError, match="conflicts with built-in alias"): + register_backend(builtin_alias, lambda: dummy_engine_class) + + def test_register_rejects_runtime_alias_as_name(self, dummy_engine_class): + """Cannot register a new backend under a runtime-registered alias name.""" + register_backend("backend_a", lambda: dummy_engine_class, aliases=["short_a"], force=True) + with pytest.raises(ValueError, match="conflicts with runtime alias"): + register_backend("short_a", lambda: dummy_engine_class) + + def test_register_rejects_backend_name_as_alias(self, dummy_engine_class): + """Cannot use an existing backend name as a new backend's alias.""" + with pytest.raises(ValueError, match="conflicts with existing backend name"): + register_backend("new_x", lambda: dummy_engine_class, aliases=["mujoco"]) + + def test_register_force_overrides_alias_conflict(self, dummy_engine_class): + """force=True bypasses all conflict checks (escape hatch).""" + # Should NOT raise + register_backend("mj", lambda: dummy_engine_class, force=True) + # Clean up — put the real mj alias back by re-importing + import importlib + + from strands_robots.simulation import factory + + importlib.reload(factory) + + +# ── Model Registry Tests ───────────────────────────────────────── + + +class TestModelRegistry: + """Test URDF/MJCF model resolution.""" + + def test_list_available_models_returns_robot_table(self): + from strands_robots.simulation.model_registry import list_available_models + + models = list_available_models() + assert isinstance(models, str) + assert "so100" in models + assert len(models) > 100 + + def test_register_and_resolve_urdf(self, tmp_path): + """Register a URDF, resolve it back — full round-trip.""" + from strands_robots.simulation.model_registry import register_urdf, resolve_urdf + + urdf_file = tmp_path / "robot.urdf" + urdf_file.write_text("") + register_urdf("test_robot_xyz", str(urdf_file)) + result = resolve_urdf("test_robot_xyz") + assert result == str(urdf_file) + + def test_list_registered_urdfs(self): + from strands_robots.simulation.model_registry import list_registered_urdfs, register_urdf + + register_urdf("list_test_bot", "/fake/list.urdf") + urdfs = list_registered_urdfs() + assert isinstance(urdfs, dict) + assert "list_test_bot" in urdfs + + +# ── Dataclass Behavioral Tests ─────────────────────────────────── + + +class TestSimModelsUsage: + """Test that simulation models behave correctly in real usage patterns.""" + + def test_sim_world_tracks_robots(self): + """SimWorld can add robots and objects — simulates real world setup.""" + world = SimWorld() + robot = SimRobot(name="so100", urdf_path="/p") + world.robots["so100"] = robot + assert "so100" in world.robots + assert world.status == SimStatus.IDLE + + def test_sim_object_preserves_originals_for_randomization(self): + """SimObject stores original position/color for domain randomization reset.""" + obj = SimObject(name="ball", shape="sphere", position=[1, 2, 3], color=[1, 0, 0, 1]) + assert obj._original_position == [1, 2, 3] + assert obj._original_color == [1, 0, 0, 1] + + def test_trajectory_step_records_episode_data(self): + """TrajectoryStep captures full observation-action pair for dataset recording.""" + step = TrajectoryStep( + timestamp=1.0, + sim_time=0.5, + robot_name="arm", + observation={"state": [1, 2, 3]}, + action={"joint_0": 0.5}, + instruction="pick up cube", + ) + assert step.robot_name == "arm" + assert step.instruction == "pick up cube" + assert step.observation["state"] == [1, 2, 3] diff --git a/tests/test_user_registry.py b/tests/test_user_registry.py new file mode 100644 index 0000000..66e5690 --- /dev/null +++ b/tests/test_user_registry.py @@ -0,0 +1,396 @@ +"""Tests for user-local robot registry and shared path utilities. + +Covers: + - strands_robots.registry.user_registry (register, unregister, list, persistence) + - strands_robots.registry.loader._merge_user_robots (user overlay merge) + - strands_robots.utils (get_base_dir, get_assets_dir, resolve_asset_path) +""" + +import json +import logging +import os +from pathlib import Path +from unittest import mock + +import pytest + +from strands_robots.registry import get_robot, list_robots, resolve_name +from strands_robots.registry.user_registry import ( + _get_user_registry_path, + _invalidate_cache, + _load_user_registry, + get_user_robots, + list_user_robots, + register_robot, + unregister_robot, +) +from strands_robots.utils import get_assets_dir, get_base_dir, resolve_asset_path + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_MINIMAL_MJCF = '' + + +@pytest.fixture(autouse=True) +def _isolate_registry(tmp_path, monkeypatch): + """Point STRANDS_BASE_DIR + STRANDS_ASSETS_DIR to temp dirs for every test. + + ``STRANDS_BASE_DIR`` controls where ``user_robots.json`` lives. + ``STRANDS_ASSETS_DIR`` controls where robot asset directories live. + The two are independent — the base dir is not derived from the assets dir. + """ + assets_dir = tmp_path / "assets" + assets_dir.mkdir() + monkeypatch.setenv("STRANDS_BASE_DIR", str(tmp_path)) + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(assets_dir)) + _invalidate_cache() + yield + _invalidate_cache() + + +def _make_robot(parent: Path, name: str = "test_bot", xml_name: str = "bot.xml") -> Path: + """Create a minimal MJCF robot directory and return its path.""" + d = parent / name + d.mkdir(parents=True, exist_ok=True) + (d / xml_name).write_text(_MINIMAL_MJCF) + return d + + +# =========================================================================== +# Registration +# =========================================================================== + + +class TestRegisterRobot: + """register_robot() stores metadata and makes the robot discoverable.""" + + def test_stores_description_category_and_joints(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + entry = register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + description="A test bot", + category="arm", + joints=3, + ) + assert entry["description"] == "A test bot" + assert entry["category"] == "arm" + assert entry["joints"] == 3 + assert entry["asset"]["model_xml"] == "bot.xml" + + def test_robot_visible_via_get_robot(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert get_robot("test_bot") is not None + + def test_robot_visible_in_list_robots(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert "test_bot" in [r["name"] for r in list_robots()] + + def test_aliases_resolve_to_canonical_name(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + aliases=["my_bot", "tb"], + ) + assert resolve_name("my_bot") == "test_bot" + assert resolve_name("tb") == "test_bot" + + def test_stores_robot_descriptions_module(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + entry = register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + robot_descriptions_module="my_pkg.test_bot", + ) + assert entry["asset"]["robot_descriptions_module"] == "my_pkg.test_bot" + + def test_stores_hardware_config(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + hw = {"lerobot_type": "so100_follower", "cameras": {"top": 0}} + entry = register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + hardware=hw, + ) + assert entry["hardware"] == hw + + +class TestRegisterRobotNameNormalization: + """Names are lower-cased, stripped, and hyphens become underscores.""" + + def test_normalizes_whitespace_hyphens_and_case(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name=" My-Bot ", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert get_robot("my_bot") is not None + + +class TestRegisterRobotDuplicates: + """Duplicate handling: raise by default, allow with overwrite=True.""" + + def test_duplicate_raises_by_default(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + with pytest.raises(ValueError, match="already in user registry"): + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + + def test_overwrite_replaces_existing(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir), description="v1") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir), description="v2", overwrite=True) + assert get_robot("test_bot")["description"] == "v2" + + def test_overriding_package_robot_logs_info(self, tmp_path, caplog): + """Registering a name that exists in the package registry emits an info log.""" + panda_dir = _make_robot(tmp_path / "assets", name="panda", xml_name="panda.xml") + with caplog.at_level(logging.INFO, logger="strands_robots.registry.user_registry"): + register_robot( + name="panda", + model_xml="panda.xml", + asset_dir=str(panda_dir), + description="Custom panda", + ) + assert any("exists in package registry" in m for m in caplog.messages) + assert get_robot("panda")["description"] == "Custom panda" + unregister_robot("panda") + + +class TestRegisterRobotValidation: + """register_robot rejects invalid inputs.""" + + def test_missing_model_xml_raises_file_not_found(self, tmp_path): + empty_dir = tmp_path / "assets" / "empty" + empty_dir.mkdir(parents=True) + with pytest.raises(FileNotFoundError, match="Model XML not found"): + register_robot(name="empty", model_xml="nope.xml", asset_dir=str(empty_dir)) + + +class TestRegisterRobotAssetDirResolution: + """asset_dir is resolved relative to STRANDS_ASSETS_DIR.""" + + def test_none_defaults_to_assets_subdir(self, tmp_path): + default_dir = _make_robot(tmp_path / "assets", name="auto_bot", xml_name="auto.xml") + entry = register_robot(name="auto_bot", model_xml="auto.xml") + assert entry["_user_asset_path"] == str(default_dir) + + def test_relative_path_resolved_against_assets(self, tmp_path): + rel_dir = tmp_path / "assets" / "sub" / "bot" + rel_dir.mkdir(parents=True) + (rel_dir / "r.xml").write_text(_MINIMAL_MJCF) + entry = register_robot(name="rel_bot", model_xml="r.xml", asset_dir="sub/bot") + assert entry["_user_asset_path"] == str(rel_dir) + + def test_absolute_path_used_as_is(self, tmp_path): + abs_dir = tmp_path / "elsewhere" / "bot" + abs_dir.mkdir(parents=True) + (abs_dir / "b.xml").write_text(_MINIMAL_MJCF) + entry = register_robot(name="abs_bot", model_xml="b.xml", asset_dir=str(abs_dir)) + assert entry["_user_asset_path"] == str(abs_dir) + + +# =========================================================================== +# Unregistration +# =========================================================================== + + +class TestUnregisterRobot: + """unregister_robot() removes from user registry only.""" + + def test_removes_registered_robot(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert unregister_robot("test_bot") is True + assert get_user_robots().get("test_bot") is None + + def test_returns_false_for_nonexistent(self): + assert unregister_robot("nonexistent") is False + + def test_does_not_affect_package_robots(self): + assert get_robot("panda") is not None + assert unregister_robot("panda") is False + assert get_robot("panda") is not None + + +# =========================================================================== +# Listing +# =========================================================================== + + +class TestListUserRobots: + """list_user_robots() returns user-registered robots only.""" + + def test_empty_when_nothing_registered(self): + assert list_user_robots() == [] + + def test_returns_registered_robot_metadata(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + description="Desc", + joints=5, + ) + result = list_user_robots() + assert len(result) == 1 + assert result[0]["name"] == "test_bot" + assert result[0]["description"] == "Desc" + assert result[0]["joints"] == 5 + assert result[0]["model_xml"] == "bot.xml" + + +# =========================================================================== +# Persistence +# =========================================================================== + + +class TestPersistence: + """User registry persists to a JSON file and survives corruption.""" + + def test_writes_json_file(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + path = _get_user_registry_path() + assert path.exists() + data = json.loads(path.read_text()) + assert "test_bot" in data["robots"] + + def test_corrupted_json_returns_empty(self): + path = _get_user_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("NOT JSON!!!") + assert _load_user_registry() == {"robots": {}} + + def test_valid_json_without_robots_key_returns_empty(self): + path = _get_user_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text('{"version": 1}') + assert _load_user_registry() == {"robots": {}} + + +# =========================================================================== +# Loader merge +# =========================================================================== + + +class TestLoaderMerge: + """_merge_user_robots gracefully handles missing user_registry module.""" + + def test_import_error_returns_data_unchanged(self): + from strands_robots.registry.loader import _merge_user_robots + + data = {"robots": {"fake": {"description": "test"}}} + with mock.patch.dict("sys.modules", {"strands_robots.registry.user_registry": None}): + result = _merge_user_robots(data) + assert "fake" in result["robots"] + + +# =========================================================================== +# STRANDS_BASE_DIR integration +# =========================================================================== + + +class TestStrandsBaseDirIntegration: + """Registry file location respects STRANDS_BASE_DIR env var. + + STRANDS_ASSETS_DIR intentionally does NOT move the registry — it only + controls where asset directories live. See utils.get_base_dir() docstring. + """ + + def test_registry_file_lives_in_base_dir(self, tmp_path): + custom = tmp_path / "custom_base" + custom.mkdir() + with mock.patch.dict(os.environ, {"STRANDS_BASE_DIR": str(custom)}, clear=False): + assert _get_user_registry_path().parent == custom + + def test_assets_dir_does_not_move_registry(self, tmp_path, monkeypatch): + """Setting only STRANDS_ASSETS_DIR must not change the registry location.""" + monkeypatch.delenv("STRANDS_BASE_DIR", raising=False) + custom_assets = tmp_path / "custom_assets" + custom_assets.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(custom_assets)) + # Registry should land under the default base, not the assets dir. + assert ".strands_robots" in str(_get_user_registry_path()) + + def test_defaults_to_dot_strands_robots(self, monkeypatch): + monkeypatch.delenv("STRANDS_BASE_DIR", raising=False) + monkeypatch.delenv("STRANDS_ASSETS_DIR", raising=False) + assert ".strands_robots" in str(_get_user_registry_path()) + + +# =========================================================================== +# Path utilities (strands_robots.utils) +# =========================================================================== + + +class TestGetAssetsDir: + """get_assets_dir() returns STRANDS_ASSETS_DIR or ~/.strands_robots/assets/.""" + + def test_default(self, monkeypatch): + monkeypatch.delenv("STRANDS_ASSETS_DIR", raising=False) + result = get_assets_dir() + assert str(result).endswith("assets") + assert ".strands_robots" in str(result) + + def test_custom(self, tmp_path, monkeypatch): + custom = tmp_path / "my_assets" + custom.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(custom)) + assert get_assets_dir() == custom + + +class TestGetBaseDir: + """get_base_dir() returns STRANDS_BASE_DIR or ~/.strands_robots/. + + It is independent of STRANDS_ASSETS_DIR by design — the base dir holds + user metadata (user_robots.json) and should not move just because the + user repoints the asset cache. + """ + + def test_default(self, monkeypatch): + monkeypatch.delenv("STRANDS_BASE_DIR", raising=False) + monkeypatch.delenv("STRANDS_ASSETS_DIR", raising=False) + assert str(get_base_dir()).endswith(".strands_robots") + + def test_custom(self, tmp_path, monkeypatch): + custom = tmp_path / "custom_base" + custom.mkdir() + monkeypatch.setenv("STRANDS_BASE_DIR", str(custom)) + assert get_base_dir() == custom + + def test_assets_dir_does_not_move_base(self, tmp_path, monkeypatch): + """STRANDS_ASSETS_DIR must not affect get_base_dir().""" + monkeypatch.delenv("STRANDS_BASE_DIR", raising=False) + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(tmp_path / "assets")) + assert str(get_base_dir()).endswith(".strands_robots") + + +class TestResolveAssetPath: + """resolve_asset_path() resolves None, relative, absolute, and ~ paths.""" + + def test_none_returns_assets_dir_plus_default_name(self, tmp_path, monkeypatch): + assets = tmp_path / "a" + assets.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(assets)) + assert resolve_asset_path(None, "robot") == assets / "robot" + + def test_relative_resolved_against_assets_dir(self, tmp_path, monkeypatch): + assets = tmp_path / "a" + assets.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(assets)) + assert resolve_asset_path("sub/dir") == assets / "sub" / "dir" + + def test_absolute_path_unchanged(self): + assert resolve_asset_path("/absolute/path") == Path("/absolute/path") + + def test_tilde_expanded(self): + result = resolve_asset_path("~/robots") + assert str(result).startswith(str(Path.home())) diff --git a/tests/test_utils.py b/tests/test_utils.py index fb9baf3..8ecf078 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -49,3 +49,57 @@ def test_dotted_module(self): """Should handle dotted module names like os.path.""" mod = require_optional("os.path") assert hasattr(mod, "join") + + +# ── safe_join / get_search_paths tests (added for PR #84 follow-up) ── + + +class TestSafeJoin: + """Tests for the centralised path-traversal guard.""" + + def test_joins_clean_paths(self, tmp_path): + from strands_robots.utils import safe_join + + result = safe_join(tmp_path, "robot/model.xml") + assert result == tmp_path / "robot" / "model.xml" + + def test_rejects_traversal(self, tmp_path): + from strands_robots.utils import safe_join + + with pytest.raises(ValueError, match="Path traversal blocked"): + safe_join(tmp_path, "../etc/passwd") + + def test_rejects_absolute_escape(self, tmp_path): + from strands_robots.utils import safe_join + + with pytest.raises(ValueError, match="Path traversal blocked"): + safe_join(tmp_path, "robot/../../etc/passwd") + + def test_same_path_is_allowed(self, tmp_path): + from strands_robots.utils import safe_join + + # Empty / dot path resolves to base itself — must not raise + result = safe_join(tmp_path, ".") + assert result == tmp_path + + +class TestGetSearchPaths: + """Tests for the centralised search-path resolver.""" + + def test_returns_assets_dir_and_cwd_assets(self, tmp_path, monkeypatch): + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(tmp_path)) + monkeypatch.chdir(tmp_path) + from strands_robots.utils import get_search_paths + + paths = get_search_paths() + assert tmp_path in paths + assert (tmp_path / "assets") in paths + + def test_returns_unique_paths(self, tmp_path, monkeypatch): + # When CWD is already the assets dir, we shouldn't list the same path twice + # (deduping is explicit in the implementation). + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(tmp_path)) + from strands_robots.utils import get_search_paths + + paths = get_search_paths() + assert len(paths) == len(set(paths))