diff --git a/README.md b/README.md index 4332e4c..0a93a93 100644 --- a/README.md +++ b/README.md @@ -486,6 +486,31 @@ while True: agent.tool.gr00t_inference(action="stop", port=8000) ``` +## Configuration + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `STRANDS_ASSETS_DIR` | Custom directory for robot model assets (MJCF, meshes) | `~/.strands_robots/assets/` | +| `GROOT_API_TOKEN` | API token for GR00T inference service | — | + +### Cache Directory + +Robot model assets (MJCF XML files and meshes) are cached in: + +``` +~/.strands_robots/ +└── assets/ # Downloaded robot models (from robot_descriptions / MuJoCo Menagerie) + ├── trs_so_arm100/ + ├── franka_emika_panda/ + └── ... +``` + +To clear the cache: `rm -rf ~/.strands_robots/assets/` + +To change the cache location: `export STRANDS_ASSETS_DIR=/path/to/custom/dir` + ## Contributing We welcome contributions! Please see: diff --git a/pyproject.toml b/pyproject.toml index 40382ec..f1a7090 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,9 +48,13 @@ groot-service = [ lerobot = [ "lerobot>=0.5.0,<0.6.0", ] +sim = [ + "robot_descriptions>=1.11.0,<2.0.0", +] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", + "strands-robots[sim]", ] dev = [ "pytest>=6.0,<9.0.0", @@ -124,7 +128,7 @@ ignore_missing_imports = false # Third-party libs without type stubs [[tool.mypy.overrides]] -module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*"] +module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*"] ignore_missing_imports = true # @tool decorator injects runtime signatures mypy cannot check diff --git a/strands_robots/assets/__init__.py b/strands_robots/assets/__init__.py new file mode 100644 index 0000000..ddff6f8 --- /dev/null +++ b/strands_robots/assets/__init__.py @@ -0,0 +1,42 @@ +"""Robot Asset Manager for Strands Robots Simulation. + +Assets are resolved from ``robot_descriptions`` package or downloaded from +MuJoCo Menagerie GitHub, cached in ``~/.strands_robots/assets/``. +Override with ``STRANDS_ASSETS_DIR`` env var. + +Implementation lives in ``assets/manager.py`` — this file is thin exports only. +""" + +from strands_robots.assets.manager import ( + get_assets_dir, + get_robot_info, + get_search_paths, + list_available_robots, + resolve_model_dir, + resolve_model_path, +) +from strands_robots.registry import ( + format_robot_table, + get_robot, + list_aliases, + list_robots, + list_robots_by_category, +) +from strands_robots.registry import ( + resolve_name as resolve_robot_name, +) + +__all__ = [ + "resolve_model_path", + "resolve_model_dir", + "resolve_robot_name", + "get_robot_info", + "list_available_robots", + "list_robots_by_category", + "list_aliases", + "format_robot_table", + "get_assets_dir", + "get_search_paths", + "get_robot", + "list_robots", +] diff --git a/strands_robots/assets/download.py b/strands_robots/assets/download.py new file mode 100644 index 0000000..eb64adc --- /dev/null +++ b/strands_robots/assets/download.py @@ -0,0 +1,435 @@ +"""Download robot model assets via ``robot_descriptions`` or custom GitHub repos. + +This module contains the core download logic for robot assets. +The ``strands_robots.tools.download_assets`` tool is a thin ``@tool`` wrapper +that delegates to :func:`download_robots` here. + +Strategy (in order of preference): + 1. ``robot_descriptions`` package — recommended by MuJoCo Menagerie. + 2. Shallow ``git clone`` fallback for Menagerie robots. + 3. Custom GitHub repos for non-Menagerie robots. + +Assets are cached in ``~/.strands_robots/assets/`` (override with +``STRANDS_ASSETS_DIR``). Install the optional dependency:: + + pip install strands-robots[sim-mujoco] # includes robot_descriptions +""" + +from __future__ import annotations + +import importlib +import logging +import os +import re +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any + +from ..registry import get_robot +from ..registry import list_robots as registry_list_robots +from ..registry import resolve_name as resolve_robot_name +from .manager import get_search_paths + +logger = logging.getLogger(__name__) + +MENAGERIE_REPO = "https://github.com/google-deepmind/mujoco_menagerie.git" + +# Only HTTPS GitHub URLs are allowed for cloning. +_ALLOWED_CLONE_URL_RE = re.compile(r"^https://github\.com/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+\.git$") + + +# ── robot_descriptions integration ──────────────────────────────────── + + +def _robot_descriptions_available() -> bool: + """Check if ``robot_descriptions`` is installed.""" + try: + import robot_descriptions # type: ignore[import-not-found] # noqa: F401 + + return True + except ImportError: + return False + + +def _resolve_robot_descriptions_module(name: str, info: dict) -> str | None: + """Resolve the ``robot_descriptions`` module name for a robot. + + Uses the ``robot_descriptions_module`` field from the registry (O(1)), + with a lightweight naming-convention fallback for unregistered robots. + + Args: + name: Canonical robot name. + info: Robot registry entry. + + Returns: + Module name (e.g. ``panda_mj_description``) or ``None``. + """ + # Primary: explicit registry entry (preferred, O(1)) + module_name: str | None = info.get("asset", {}).get("robot_descriptions_module") + if module_name: + return str(module_name) + + # Fallback: try common naming conventions (max 3 imports) + asset_dir = info.get("asset", {}).get("dir", "") + candidates = [ + f"{asset_dir}_mj_description", + f"{name}_mj_description", + f"{name}_description", + ] + for candidate in candidates: + if not re.match(r"^[a-z0-9_]+$", candidate): + continue + try: + importlib.import_module(f"robot_descriptions.{candidate}") + logger.warning( + "Resolved '%s' via naming heuristic → '%s'. " + "Consider adding 'robot_descriptions_module' to the registry.", + name, + candidate, + ) + return candidate + except ImportError: + continue + + return None + + +# ── Helpers ─────────────────────────────────────────────────────────── + + +def get_user_assets_dir() -> Path: + """Get user-level asset cache directory.""" + custom = os.getenv("STRANDS_ASSETS_DIR") + directory = Path(custom) if custom else Path.home() / ".strands_robots" / "assets" + directory.mkdir(parents=True, exist_ok=True) + return directory + + +def _safe_join(base: Path, untrusted: str) -> Path: + """Join *base* with an untrusted relative path, rejecting traversal.""" + joined = Path(os.path.normpath(base / untrusted)) + base_norm = Path(os.path.normpath(base)) + if not (joined == base_norm or str(joined).startswith(str(base_norm) + os.sep)): + raise ValueError(f"Path traversal blocked: {untrusted!r} escapes {base}") + return joined + + +def _needs_download(name: str, info: dict[str, Any] | None, force: bool = False) -> bool: + """Return *True* if a robot's mesh files are missing.""" + if info is None: + return False + asset = info.get("asset", {}) + if not asset: + return False + + xml_file, asset_dir = asset["model_xml"], asset["dir"] + + for search_dir in get_search_paths(): + model_path = search_dir / asset_dir / xml_file + if not model_path.exists(): + continue + try: + content = model_path.read_text() + mesh_files = re.findall(r'file="([^"]+\.(?:stl|STL|obj|OBJ|msh))"', content) + if not mesh_files: + return False + meshdir_match = re.search(r'meshdir="([^"]*)"', content) + meshdir = meshdir_match.group(1) if meshdir_match else "" + for mesh in mesh_files: + if not (model_path.parent / meshdir / mesh).exists(): + return True + return force + except Exception: + return True + + return True + + +def _get_source(info: dict[str, Any] | None) -> dict[str, Any]: + """Get download source for a robot. Defaults to ``menagerie``.""" + if info is None: + return {"type": "menagerie"} + source = info.get("asset", {}).get("source", {}) + return source if source else {"type": "menagerie"} + + +def _shallow_clone(repo_url: str, dest: str, *, timeout: int = 120) -> None: + """Shallow-clone *repo_url* into *dest*. + + Only HTTPS ``github.com`` URLs are accepted — ``ssh://``, ``git://``, + ``file://``, and other schemes are rejected to prevent command-injection + and SSRF risks. + + Raises: + ValueError: If *repo_url* does not match the allowed HTTPS GitHub pattern. + subprocess.CalledProcessError: If the ``git clone`` command fails. + subprocess.TimeoutExpired: If the clone exceeds *timeout* seconds. + """ + if not _ALLOWED_CLONE_URL_RE.match(repo_url): + raise ValueError(f"Blocked clone URL (only HTTPS github.com allowed): {repo_url!r}") + logger.info("Cloning %s (this may take a moment)...", repo_url) + subprocess.run( + ["git", "clone", "--depth", "1", repo_url, dest], + check=True, + capture_output=True, + timeout=timeout, + ) + + +def _copy_and_clean(src: Path, dst: Path) -> None: + """Copy *src* tree to *dst* and remove non-essential files.""" + shutil.copytree(str(src), str(dst), dirs_exist_ok=True) + for pattern in ("README.md", "LICENSE", "CHANGELOG.md", "*.png", "*.jpg", ".git*"): + for path in dst.glob(pattern): + if path.is_file(): + path.unlink() + elif path.is_dir(): + shutil.rmtree(str(path), ignore_errors=True) + + +# ── Download backends ───────────────────────────────────────────────── + + +def _download_via_robot_descriptions(robots: dict[str, dict], dest_dir: Path) -> dict[str, str]: + """Download robots using the ``robot_descriptions`` package. + + Imports only the specific module for each robot (O(1) per robot), + using the ``robot_descriptions_module`` field from the registry. + The import triggers the upstream clone on first use, then we symlink + ``PACKAGE_PATH`` into our asset cache. + """ + results: dict[str, str] = {} + if not robots: + return results + + for name, info in robots.items(): + asset_dir = info["asset"]["dir"] + module_name = _resolve_robot_descriptions_module(name, info) + if module_name is None: + results[name] = "skipped: no robot_descriptions module found" + continue + if not re.match(r"^[a-z0-9_]+$", module_name): + results[name] = f"skipped: invalid module name: {module_name}" + continue + + try: + mod = importlib.import_module(f"robot_descriptions.{module_name}") + package_path = Path(mod.PACKAGE_PATH) + if not package_path.exists(): + results[name] = f"failed: PACKAGE_PATH missing: {package_path}" + continue + + dst = _safe_join(dest_dir, asset_dir) + if dst.is_symlink() and dst.resolve() == package_path.resolve(): + # Validate existing symlink still has the expected XML + expected_xml = dst / info["asset"]["model_xml"] + if expected_xml.exists(): + results[name] = "downloaded" + continue + # Stale symlink — remove and re-download via git + dst.unlink() + results[name] = f"failed: stale symlink — {info['asset']['model_xml']} not found in {package_path}" + continue + if dst.exists() or dst.is_symlink(): + dst.unlink() if dst.is_symlink() else shutil.rmtree(str(dst)) + + try: + dst.symlink_to(package_path) + except OSError: + shutil.copytree(str(package_path), str(dst), dirs_exist_ok=True) + + # Validate: expected XML must exist in the linked/copied dir + expected_xml = dst / info["asset"]["model_xml"] + if not expected_xml.exists(): + logger.warning( + "robot_descriptions module '%s' linked for %s but " + "expected XML '%s' not found — falling back to git", + module_name, + name, + info["asset"]["model_xml"], + ) + if dst.is_symlink(): + dst.unlink() + else: + shutil.rmtree(str(dst), ignore_errors=True) + results[name] = ( + f"failed: XML mismatch — module '{module_name}' does not contain {info['asset']['model_xml']}" + ) + continue + + results[name] = "downloaded" + except Exception as exc: + results[name] = f"failed: {exc}" + logger.warning("robot_descriptions failed for %s: %s", name, exc) + + return results + + +def _download_via_git(robots: dict[str, dict], dest_dir: Path) -> dict[str, str]: + """Fallback: shallow-clone Menagerie and copy robot directories.""" + results: dict[str, str] = {} + if not robots: + return results + + with tempfile.TemporaryDirectory() as tmpdir: + clone_dir = os.path.join(tmpdir, "mujoco_menagerie") + try: + _shallow_clone(MENAGERIE_REPO, clone_dir) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, ValueError) as exc: + reason = "timeout" if isinstance(exc, subprocess.TimeoutExpired) else str(exc)[:100] + return {n: f"failed: git clone {reason}" for n in robots} + + for name, info in robots.items(): + asset_dir = info["asset"]["dir"] + src = _safe_join(Path(clone_dir), asset_dir) + if not src.exists(): + results[name] = f"failed: {asset_dir} not in menagerie" + continue + try: + _copy_and_clean(src, _safe_join(dest_dir, asset_dir)) + results[name] = "downloaded" + except Exception as exc: + results[name] = f"failed: {exc}" + + return results + + +def _download_from_github(name: str, info: dict, dest_dir: Path) -> str: + """Download a robot from a custom GitHub repo (``asset.source``).""" + source = info["asset"]["source"] + repo = source["repo"] + if not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", repo): + return f"failed: invalid repo format: {repo}" + + subdir = source.get("subdir", "") + asset_dir = info["asset"]["dir"] + + with tempfile.TemporaryDirectory() as tmpdir: + clone_dir = os.path.join(tmpdir, "repo") + try: + # URL validation is enforced inside _shallow_clone itself + _shallow_clone(f"https://github.com/{repo}.git", clone_dir) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, ValueError) as exc: + reason = "timeout" if isinstance(exc, subprocess.TimeoutExpired) else str(exc)[:100] + return f"failed: git clone {reason}" + + src = Path(clone_dir) / subdir if subdir else Path(clone_dir) + if not src.exists(): + return f"failed: subdir '{subdir}' not found in {repo}" + + dst = _safe_join(dest_dir, asset_dir) + try: + _copy_and_clean(src, dst) + return "downloaded" + except Exception as exc: + return f"failed: {exc}" + + +# ── Orchestrator ────────────────────────────────────────────────────── + + +def download_robots( + names: list[str] | None = None, + category: str | None = None, + force: bool = False, +) -> dict[str, Any]: + """Download robot model assets from their respective sources. + + Strategy (in order of preference): + 1. ``robot_descriptions`` package — recommended by MuJoCo Menagerie. + 2. Shallow ``git clone`` fallback for Menagerie robots. + 3. Custom GitHub repos for non-Menagerie robots. + + Args: + names: Robot names to download (``None`` = all sim robots). + category: Filter by category (arm, humanoid, mobile, …). + force: Re-download even if present. + + Returns: + Dict with downloaded/skipped/failed counts, names, and details. + """ + dest_dir = get_user_assets_dir() + # Filter None values — get_robot() can return None for unknown names + all_sim: dict[str, dict[str, Any]] = { + r["name"]: info for r in registry_list_robots(mode="sim") if (info := get_robot(r["name"])) is not None + } + + # Resolve requested robots + if names: + robots: dict[str, dict[str, Any]] = {} + for name in names: + canonical = resolve_robot_name(name) + if canonical in all_sim: + robots[canonical] = all_sim[canonical] + else: + logger.warning("Unknown robot: %s (resolved: %s)", name, canonical) + elif category: + robots = {n: i for n, i in all_sim.items() if i.get("category") == category} + else: + robots = dict(all_sim) + + if not robots: + return {"downloaded": 0, "skipped": 0, "failed": 0, "message": "No matching robots found."} + + # Partition: needs download vs already present + to_download: dict[str, dict[str, Any]] = {} + skipped: list[str] = [] + for name, info in robots.items(): + if _needs_download(name, info, force): + to_download[name] = info + else: + skipped.append(name) + + if not to_download: + return { + "downloaded": 0, + "skipped": len(skipped), + "failed": 0, + "skipped_names": skipped, + "message": f"All {len(robots)} robots already have assets. Use force=True to re-download.", + } + + # Partition by source type + menagerie_robots: dict[str, Any] = {} + github_robots: dict[str, Any] = {} + for name, info in to_download.items(): + source = _get_source(info) + bucket = github_robots if source["type"] == "github" else menagerie_robots + bucket[name] = info + + # Download Menagerie robots (robot_descriptions → git fallback) + results: dict[str, str] = {} + if menagerie_robots: + if _robot_descriptions_available(): + results.update(_download_via_robot_descriptions(menagerie_robots, dest_dir)) + # Retry failures with git clone + retry = { + n: menagerie_robots[n] for n, r in results.items() if r.startswith("failed") or r.startswith("skipped") + } + if retry: + results.update(_download_via_git(retry, dest_dir)) + else: + results.update(_download_via_git(menagerie_robots, dest_dir)) + + # Download custom GitHub robots + for name, info in github_robots.items(): + results[name] = _download_from_github(name, info, dest_dir) + + downloaded = [n for n, r in results.items() if r == "downloaded"] + failed = {n: r for n, r in results.items() if r != "downloaded"} + method = "robot_descriptions" if _robot_descriptions_available() else "git clone" + + return { + "downloaded": len(downloaded), + "skipped": len(skipped), + "failed": len(failed), + "downloaded_names": downloaded, + "skipped_names": skipped, + "failed_names": list(failed), + "failed_details": failed, + "assets_dir": str(dest_dir), + "method": method, + "message": (f"{len(downloaded)} downloaded ({method}), {len(skipped)} already present, {len(failed)} failed."), + } diff --git a/strands_robots/assets/manager.py b/strands_robots/assets/manager.py new file mode 100644 index 0000000..cc77ece --- /dev/null +++ b/strands_robots/assets/manager.py @@ -0,0 +1,290 @@ +"""Robot Asset Manager for Strands Robots Simulation. + +Resolves robot model files (MJCF XML) from: + 1. ``STRANDS_ASSETS_DIR`` env var (user override) + 2. User cache (``~/.strands_robots/assets/``) + 3. ``robot_descriptions`` package (MuJoCo Menagerie) + 4. Project-local ``./assets/`` +""" + +import logging +import os +from pathlib import Path + +from strands_robots.registry import ( + get_robot, + list_robots, +) +from strands_robots.registry import ( + resolve_name as resolve_robot_name, +) +from strands_robots.utils import get_assets_dir + +logger = logging.getLogger(__name__) + + +# ───────────────────────────────────────────────────────────────────── +# Path safety +# ───────────────────────────────────────────────────────────────────── + + +def _safe_join(base: Path, untrusted: str) -> Path: + """Join *base* with an untrusted relative path, rejecting traversal. + + Raises: + ValueError: If the resulting path escapes *base*. + """ + joined = Path(os.path.normpath(base / untrusted)) + base_norm = Path(os.path.normpath(base)) + if not (joined == base_norm or str(joined).startswith(str(base_norm) + os.sep)): + raise ValueError(f"Path traversal blocked: {untrusted!r} escapes {base}") + return joined + + +# ───────────────────────────────────────────────────────────────────── +# Asset directory resolution +# ───────────────────────────────────────────────────────────────────── + + +def get_search_paths() -> list[Path]: + """Get ordered list of asset search paths. + + Order (local assets take priority over defaults): + 1. User asset dir (``STRANDS_ASSETS_DIR`` or ``~/.strands_robots/assets/``) + 2. CWD/assets (project-local) + + Note: + ``STRANDS_ASSETS_DIR`` handling is centralised in + :func:`strands_robots.utils.get_assets_dir` — no need to read + the env var again here. + """ + paths: list[Path] = [] + + # User asset dir (respects STRANDS_ASSETS_DIR if set) + user_cache = get_assets_dir() + if user_cache not in paths: + paths.append(user_cache) + + # CWD/assets (project-local) + cwd_assets = Path.cwd() / "assets" + if cwd_assets not in paths: + paths.append(cwd_assets) + + return paths + + +# ───────────────────────────────────────────────────────────────────── +# Model path resolution (delegates to registry) +# ───────────────────────────────────────────────────────────────────── + + +def _auto_download_robot(name: str, info: dict) -> bool: + """Auto-download a single robot's assets via robot_descriptions. + + Called lazily when resolve_model_path finds XML but no meshes. + Returns True if download succeeded. + """ + try: + # Lazy import: avoids circular import (manager ↔ download) at module level. + # download.py depends on optional robot_descriptions package. + from .download import ( + _download_from_github, + _download_via_robot_descriptions, + _robot_descriptions_available, + get_user_assets_dir, + ) + except ImportError: + logger.warning("Auto-download unavailable: install robot_descriptions for automatic asset downloads") + return False + + dest_dir = get_user_assets_dir() + canonical = resolve_robot_name(name) + + # Try robot_descriptions first (covers most robots) + if _robot_descriptions_available(): + results = _download_via_robot_descriptions({canonical: info}, dest_dir) + if results.get(canonical, "").startswith("downloaded"): + logger.info("Auto-downloaded %s via robot_descriptions", canonical) + return True + + # Try custom GitHub source + source = info.get("asset", {}).get("source", {}) + if source.get("type") == "github": + result = _download_from_github(canonical, info, dest_dir) + if result.startswith("downloaded"): + logger.info("Auto-downloaded %s from GitHub", canonical) + return True + + return False + + +def _has_meshes(directory: Path) -> bool: + """Check if a directory tree contains mesh files.""" + _MESH_EXTS = {".stl", ".obj", ".msh", ".ply"} + return any(f.suffix.lower() in _MESH_EXTS for f in directory.rglob("*") if f.is_file()) + + +def _resolve_candidates(asset_dir_name: str, xml_file: str, name: str) -> list[Path]: + """Resolve candidate paths for a robot XML, with path-traversal protection. + + Uses ``_safe_join`` to prevent ``../`` in registry-sourced ``asset_dir_name`` + or ``xml_file`` from escaping the search directories. + """ + candidates: list[Path] = [] + for search_dir in get_search_paths(): + try: + model_path = _safe_join(search_dir, f"{asset_dir_name}/{xml_file}") + except ValueError: + logger.warning("Path traversal attempt blocked for robot: %s", name) + return [] + if model_path.exists(): + candidates.append(model_path) + return candidates + + +def resolve_model_path( + name: str, + prefer_scene: bool = False, +) -> Path | None: + """Resolve a robot name to its MJCF model XML path. + + Looks up the robot in ``registry/robots.json``, then searches + the asset directories for the actual file. If XML is found but + mesh files are missing, automatically downloads them via + ``robot_descriptions`` before returning. + + Args: + name: Robot name (canonical or alias). + prefer_scene: If True, return scene XML (with ground/lights) + instead of bare model XML. + + Returns: + Path to the MJCF XML file, or None if not found. + + Examples:: + + resolve_model_path("so100") # → .../trs_so_arm100/so_arm100.xml + resolve_model_path("so100", prefer_scene=True) # → .../trs_so_arm100/scene.xml + resolve_model_path("franka") # → .../franka_emika_panda/panda.xml + """ + info = get_robot(name) + if not info or "asset" not in info: + logger.warning("Unknown robot or no asset: %s", name) + return None + + asset = info["asset"] + # Explicit str() casts: dict subscript returns Any, but Path / Any → Any + xml_file: str = str(asset["scene_xml"] if prefer_scene else asset["model_xml"]) + asset_dir_name: str = str(asset["dir"]) + + candidates: list[Path] = [] + + # Check user-registered asset path first (highest priority) + user_path = info.get("_user_asset_path") + if user_path: + user_model = Path(user_path) / xml_file + if user_model.exists(): + candidates.append(user_model) + + # Search standard paths with traversal protection + candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) + + if not candidates: + # No XML found at all — try auto-download, then re-search + logger.info("No XML found for %s, attempting auto-download...", name) + if _auto_download_robot(name, info): + candidates.extend(_resolve_candidates(asset_dir_name, xml_file, name)) + + if not candidates: + logger.warning("Robot model not found: %s → %s/%s", name, asset_dir_name, xml_file) + return None + + # Prefer the candidate whose directory contains mesh files, + # because an XML without meshes will fail to load in MuJoCo. + for path in candidates: + if _has_meshes(path.parent): + logger.debug("Resolved %s → %s (has meshes)", name, path) + return Path(path) + + # XML found but no meshes — auto-download and re-check + logger.info("XML found for %s but no meshes, attempting auto-download...", name) + if _auto_download_robot(name, info): + # Re-scan after download (new symlinks may have appeared) + refreshed = _resolve_candidates(asset_dir_name, xml_file, name) + for path in refreshed: + if _has_meshes(path.parent): + logger.debug("Resolved %s → %s (auto-downloaded)", name, path) + return Path(path) + + # Final fallback: return first candidate (some robots have no meshes) + logger.debug("Resolved %s → %s (no meshes available)", name, candidates[0]) + return Path(candidates[0]) + + +def resolve_model_dir(name: str) -> Path | None: + """Resolve a robot name to its asset directory (containing XML + meshes). + + Args: + name: Robot name (canonical or alias). + + Returns: + Path to the robot's asset directory, or None if not found. + """ + info = get_robot(name) + if not info or "asset" not in info: + return None + + asset_dir: str = str(info["asset"]["dir"]) + for search_dir in get_search_paths(): + try: + dir_path = _safe_join(search_dir, asset_dir) + except ValueError: + logger.warning("Path traversal attempt blocked in resolve_model_dir: %s", asset_dir) + return None + if dir_path.exists(): + return Path(dir_path) + return None + + +def get_robot_info(name: str) -> dict | None: + """Get information about a robot model. + + Args: + name: Robot name (canonical or alias). + + Returns: + Dict with description, category, joints, asset info, etc. + """ + info = get_robot(name) + if info is None: + return None + result = dict(info) + result["canonical_name"] = resolve_robot_name(name) + path = resolve_model_path(name) + result["resolved_path"] = str(path) if path else None + result["available"] = path is not None + return result + + +def list_available_robots() -> list[dict]: + """List all available robot models with their info. + + Returns: + List of dicts with name, description, joints, category, available, path. + """ + robots = [] + for r in list_robots(mode="sim"): + path = resolve_model_path(r["name"]) + info = get_robot(r["name"]) or {} + robots.append( + { + "name": r["name"], + "description": r.get("description", ""), + "joints": r.get("joints"), + "category": r.get("category", ""), + "dir": info.get("asset", {}).get("dir", ""), + "available": path is not None, + "path": str(path) if path else None, + } + ) + return robots diff --git a/strands_robots/registry/__init__.py b/strands_robots/registry/__init__.py index 3430ce1..2d6ba3d 100644 --- a/strands_robots/registry/__init__.py +++ b/strands_robots/registry/__init__.py @@ -29,7 +29,7 @@ policies.json ← policy providers (shorthands/urls inside each entry) """ -from .loader import reload +from .loader import invalidate_cache, reload from .policies import ( build_policy_kwargs, get_policy_provider, @@ -48,6 +48,11 @@ list_robots_by_category, resolve_name, ) +from .user_registry import ( + list_user_robots, + register_robot, + unregister_robot, +) __all__ = [ # Robot registry @@ -66,6 +71,11 @@ "resolve_policy", "import_policy_class", "build_policy_kwargs", + # User-local registry + "register_robot", + "unregister_robot", + "list_user_robots", # Utilities "reload", + "invalidate_cache", ] diff --git a/strands_robots/registry/loader.py b/strands_robots/registry/loader.py index 188271d..99da64e 100644 --- a/strands_robots/registry/loader.py +++ b/strands_robots/registry/loader.py @@ -35,6 +35,11 @@ def _load(name: str) -> dict: if name not in _cache or _mtimes.get(name) != mtime: with open(path, encoding="utf-8") as f: data = json.load(f) + + # Merge user-local robot registry (overlay on top of package JSON) + if name == "robots": + data = _merge_user_robots(data) + _validate(name, data) _cache[name] = data _mtimes[name] = mtime @@ -43,6 +48,29 @@ def _load(name: str) -> dict: return _cache[name] +def _merge_user_robots(data: dict) -> dict: + """Merge user-local robot registry on top of package robots.json. + + User entries override package entries on name collision. + """ + try: + from .user_registry import get_user_robots + except ImportError: + return data + + user_robots = get_user_robots() + if not user_robots: + return data + + merged = dict(data) + merged_robots = dict(merged.get("robots", {})) + merged_robots.update(user_robots) + merged["robots"] = merged_robots + + logger.debug("Merged %d user-registered robot(s) into registry", len(user_robots)) + return merged + + def _validate(name: str, data: dict) -> None: """Validate uniqueness constraints after loading a registry file. @@ -103,3 +131,17 @@ def reload() -> None: """Force-reload all registry files (clears mtime cache).""" _cache.clear() _mtimes.clear() + + +def invalidate_cache(name: str | None = None) -> None: + """Invalidate cached registry data, forcing a reload on next access. + + Args: + name: Registry name to invalidate (e.g. "robots"). If None, clears all. + """ + if name is None: + _cache.clear() + _mtimes.clear() + else: + _cache.pop(name, None) + _mtimes.pop(name, None) diff --git a/strands_robots/registry/robots.json b/strands_robots/registry/robots.json index 5c552b2..be2e203 100644 --- a/strands_robots/registry/robots.json +++ b/strands_robots/registry/robots.json @@ -1,571 +1,966 @@ { - "robots": { - "so100": { - "description": "TrossenRobotics SO-ARM100 (6-DOF, Feetech servos)", - "category": "arm", - "joints": 13, - "asset": { - "dir": "trs_so_arm100", - "model_xml": "so_arm100.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "so_arm100_mj_description" - }, - "hardware": { - "lerobot_type": "so100_follower" - }, - "legacy_urdf": "so100/so100.urdf", - "aliases": [ - "so100_4cam", - "so100_dualcam", - "so100_follower", - "so_arm100", - "trs_so_arm100" - ] - }, - "so101": { - "description": "RobotStudio SO-101 (6-DOF, upgraded SO-100)", - "category": "arm", - "joints": 9, - "asset": { - "dir": "robotstudio_so101", - "model_xml": "so101.xml", - "scene_xml": "scene_box.xml", - "robot_descriptions_module": "so_arm101_mj_description" - }, - "hardware": { - "lerobot_type": "so101_follower" - }, - "aliases": [ - "robotstudio_so101", - "so101_dualcam", - "so101_follower", - "so101_leader", - "so101_tricam" - ] - }, - "koch": { - "description": "Koch v1.1 Low Cost Robot Arm (6-DOF, Dynamixel)", - "category": "arm", - "joints": 7, - "asset": { - "dir": "low_cost_robot_arm", - "model_xml": "low_cost_robot_arm.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "low_cost_robot_arm_mj_description" - }, - "hardware": { - "lerobot_type": "koch_follower" - }, - "aliases": [ - "koch_follower", - "koch_v1.1", - "low_cost_robot_arm" - ] - }, - "panda": { - "description": "Franka Emika Panda (7-DOF + gripper)", - "category": "arm", - "joints": 7, - "asset": { - "dir": "franka_emika_panda", - "model_xml": "panda.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "panda_mj_description" - }, - "legacy_urdf": "panda/panda.urdf", - "aliases": [ - "bimanual_panda_gripper", - "bimanual_panda_hand", - "franka", - "franka_emika_panda", - "franka_panda", - "libero_panda", - "oxe_droid", - "single_panda_gripper" - ] - }, - "fr3": { - "description": "Franka Research 3 (7-DOF + gripper)", - "category": "arm", - "joints": 8, - "asset": { - "dir": "franka_fr3", - "model_xml": "fr3.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "fr3_mj_description" - }, - "aliases": [ - "franka_fr3", - "franka_fr3_v2" - ] - }, - "ur5e": { - "description": "Universal Robots UR5e (6-DOF industrial)", - "category": "arm", - "joints": 8, - "asset": { - "dir": "universal_robots_ur5e", - "model_xml": "ur5e.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "ur5e_mj_description" - } - }, - "kuka_iiwa": { - "description": "KUKA LBR iiwa 14 (7-DOF collaborative)", - "category": "arm", - "joints": 11, - "asset": { - "dir": "kuka_iiwa_14", - "model_xml": "iiwa14.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "iiwa14_mj_description" - }, - "aliases": [ - "kuka_iiwa_14" - ] - }, - "kinova_gen3": { - "description": "Kinova Gen3 (7-DOF lightweight)", - "category": "arm", - "joints": 7, - "asset": { - "dir": "kinova_gen3", - "model_xml": "gen3.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "gen3_mj_description" - } - }, - "xarm7": { - "description": "UFactory xArm 7 (7-DOF + gripper)", - "category": "arm", - "joints": 13, - "asset": { - "dir": "ufactory_xarm7", - "model_xml": "xarm7.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "xarm7_mj_description" - }, - "aliases": [ - "ufactory_xarm7" - ] - }, - "vx300s": { - "description": "Trossen ViperX 300s (6-DOF + gripper)", - "category": "arm", - "joints": 19, - "asset": { - "dir": "trossen_vx300s", - "model_xml": "vx300s.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "viper_mj_description" - }, - "aliases": [ - "oxe_widowx", - "trossen_vx300s", - "viper_x300s" - ] - }, - "arx_l5": { - "description": "ARX L5 (6-DOF lightweight arm)", - "category": "arm", - "joints": 11, - "asset": { - "dir": "arx_l5", - "model_xml": "arx_l5.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "arx_l5_mj_description" - } - }, - "piper": { - "description": "AgileX Piper (6-DOF + gripper)", - "category": "arm", - "joints": 11, - "asset": { - "dir": "agilex_piper", - "model_xml": "piper.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "piper_mj_description" - }, - "aliases": [ - "agilex_piper" - ] - }, - "z1": { - "description": "Unitree Z1 (6-DOF + gripper)", - "category": "arm", - "joints": 8, - "asset": { - "dir": "unitree_z1", - "model_xml": "z1_gripper.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "z1_mj_description" - }, - "aliases": [ - "unitree_z1" - ] - }, - "openarm": { - "description": "Enactic OpenArm (7-DOF, DAMIAO motors, CAN bus)", - "category": "arm", - "joints": 9, - "asset": { - "dir": "enactic_openarm", - "model_xml": "openarm.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "openarm_v1_mj_description" - }, - "aliases": [ - "enactic_openarm", - "open_arm", - "openarm_v10" - ] - }, - "aloha": { - "description": "ALOHA Bimanual (2x ViperX 300s, 14-DOF + 2 grippers)", - "category": "bimanual", - "joints": 28, - "asset": { - "dir": "aloha", - "model_xml": "aloha.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "aloha_mj_description" - }, - "hardware": { - "lerobot_type": "bi_so_follower" - }, - "aliases": [ - "agibot_dual_arm", - "agibot_dual_arm_dexhand", - "agibot_dual_arm_full", - "agibot_dual_arm_gripper", - "agibot_genie1", - "bi_so_follower", - "galaxea_r1_pro" - ] - }, - "trossen_wxai": { - "description": "Trossen WidowX AI Bimanual", - "category": "bimanual", - "joints": 17, - "asset": { - "dir": "trossen_wxai", - "model_xml": "trossen_ai_bimanual.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "widow_mj_description" - }, - "aliases": [ - "trossen_ai_bimanual" - ] - }, - "shadow_hand": { - "description": "Shadow Dexterous Hand (24-DOF)", - "category": "hand", - "joints": 45, - "asset": { - "dir": "shadow_hand", - "model_xml": "left_hand.xml", - "scene_xml": "scene_left.xml", - "robot_descriptions_module": "shadow_hand_mj_description" - }, - "aliases": [ - "shadow_dexee" - ] - }, - "leap_hand": { - "description": "LEAP Hand (16-DOF dexterous)", - "category": "hand", - "joints": 41, - "asset": { - "dir": "leap_hand", - "model_xml": "left_hand.xml", - "scene_xml": "scene_left.xml", - "robot_descriptions_module": "leap_hand_mj_description" - } - }, - "robotiq_2f85": { - "description": "Robotiq 2F-85 Gripper (2-finger adaptive)", - "category": "hand", - "joints": 16, - "asset": { - "dir": "robotiq_2f85", - "model_xml": "2f85.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "robotiq_2f85_mj_description" - }, - "aliases": [ - "robotiq", - "robotiq_2f85_v4" - ] - }, - "fourier_n1": { - "description": "Fourier N1 / GR-1 Humanoid (26-DOF)", - "category": "humanoid", - "joints": 26, - "asset": { - "dir": "fourier_n1", - "model_xml": "n1.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "n1_mj_description" - }, - "aliases": [ - "fourier_gr1", - "fourier_gr1_arms_only", - "fourier_gr1_arms_waist", - "fourier_gr1_full_upper_body", - "gr1" - ] - }, - "unitree_g1": { - "description": "Unitree G1 Humanoid (29-DOF + dexterous hands)", - "category": "humanoid", - "joints": 46, - "asset": { - "dir": "unitree_g1", - "model_xml": "g1.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "g1_mj_description" - }, - "hardware": { - "lerobot_type": "unitree_g1" - }, - "legacy_urdf": "unitree_g1/g1.urdf", - "aliases": [ - "g1", - "g1_wbc", - "unitree_g1_full_body", - "unitree_g1_locomanip", - "unitree_g1_wbc" - ] - }, - "unitree_h1": { - "description": "Unitree H1 Humanoid (19-DOF)", - "category": "humanoid", - "joints": 20, - "asset": { - "dir": "unitree_h1", - "model_xml": "h1.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "h1_mj_description" - }, - "aliases": [ - "h1" - ] - }, - "apollo": { - "description": "Apptronik Apollo Humanoid (34-DOF)", - "category": "humanoid", - "joints": 34, - "asset": { - "dir": "apptronik_apollo", - "model_xml": "apptronik_apollo.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "apollo_mj_description" - }, - "aliases": [ - "apptronik_apollo" - ] - }, - "cassie": { - "description": "Agility Cassie Bipedal Robot", - "category": "humanoid", - "joints": 28, - "asset": { - "dir": "agility_cassie", - "model_xml": "cassie.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "cassie_mj_description" - }, - "aliases": [ - "agility_cassie" - ] - }, - "open_duck_mini": { - "description": "Open Duck Mini V2 (16-DOF expressive biped, Feetech servos)", - "category": "humanoid", - "joints": 16, - "asset": { - "dir": "open_duck_mini_v2", - "model_xml": "open_duck_mini_v2.xml", - "scene_xml": "scene.xml", - "source": { - "type": "github", - "repo": "apirrone/Open_Duck_Mini", - "subdir": "mini_bdx/robots/open_duck_mini_v2" + "robots": { + "crazyflie": { + "description": "Bitcraze Crazyflie 2 Nano-Quadcopter", + "category": "aerial", + "joints": 1, + "asset": { + "dir": "bitcraze_crazyflie_2", + "model_xml": "cf2.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "cf2_mj_description" + }, + "aliases": [ + "cf2", + "bitcraze_crazyflie" + ] + }, + "skydio_x2": { + "description": "Skydio X2 Autonomous Drone", + "category": "aerial", + "joints": 1, + "asset": { + "dir": "skydio_x2", + "model_xml": "x2.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "skydio_x2_mj_description" + } + }, + "arx_l5": { + "description": "ARX L5 (6-DOF lightweight arm)", + "category": "arm", + "joints": 11, + "asset": { + "dir": "arx_l5", + "model_xml": "arx_l5.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "arx_l5_mj_description" + } + }, + "dynamixel_2r": { + "description": "Dynamixel 2R Educational Arm (2-DOF)", + "category": "arm", + "joints": 2, + "asset": { + "dir": "dynamixel_2r", + "model_xml": "dynamixel_2r.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "dynamixel_2r_mj_description" + } + }, + "fr3": { + "description": "Franka Research 3 (7-DOF + gripper)", + "category": "arm", + "joints": 8, + "asset": { + "dir": "franka_fr3", + "model_xml": "fr3.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "fr3_mj_description" + }, + "aliases": [ + "franka_fr3" + ] + }, + "fr3_v2": { + "description": "Franka Research 3 v2 (7-DOF + gripper, updated)", + "category": "arm", + "joints": 7, + "asset": { + "dir": "franka_fr3_v2", + "model_xml": "fr3v2.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "fr3_v2_mj_description" + }, + "aliases": [ + "franka_fr3_v2" + ] + }, + "hope_jr": { + "description": "Hope Junior arm", + "category": "arm", + "hardware": { + "lerobot_type": "hope_jr" + } + }, + "kinova_gen3": { + "description": "Kinova Gen3 (7-DOF lightweight)", + "category": "arm", + "joints": 7, + "asset": { + "dir": "kinova_gen3", + "model_xml": "gen3.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "gen3_mj_description" + } + }, + "koch": { + "description": "Koch v1.1 Low Cost Robot Arm (6-DOF, Dynamixel)", + "category": "arm", + "joints": 7, + "asset": { + "dir": "low_cost_robot_arm", + "model_xml": "low_cost_robot_arm.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "low_cost_robot_arm_mj_description" + }, + "hardware": { + "lerobot_type": "koch_follower" + }, + "aliases": [ + "koch_follower", + "koch_v1.1", + "low_cost_robot_arm" + ] + }, + "kuka_iiwa": { + "description": "KUKA LBR iiwa 14 (7-DOF collaborative)", + "category": "arm", + "joints": 11, + "asset": { + "dir": "kuka_iiwa_14", + "model_xml": "iiwa14.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "iiwa14_mj_description" + }, + "aliases": [ + "kuka_iiwa_14" + ] + }, + "omx": { + "description": "OMX Robot Arm (ROBOTIS, CAN bus motors)", + "category": "arm", + "hardware": { + "lerobot_type": "omx" + }, + "aliases": [ + "omx_follower", + "omx_robot", + "robotis_omx" + ] + }, + "openarm": { + "description": "Enactic OpenArm (7-DOF, DAMIAO motors, CAN bus)", + "category": "arm", + "joints": 9, + "asset": { + "dir": "enactic_openarm", + "model_xml": "openarm.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "openarm_v1_mj_description" + }, + "aliases": [ + "enactic_openarm", + "open_arm", + "openarm_v10" + ] + }, + "panda": { + "description": "Franka Emika Panda (7-DOF + gripper)", + "category": "arm", + "joints": 7, + "asset": { + "dir": "franka_emika_panda", + "model_xml": "panda.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "panda_mj_description" + }, + "legacy_urdf": "panda/panda.urdf", + "aliases": [ + "bimanual_panda_gripper", + "bimanual_panda_hand", + "franka", + "franka_emika_panda", + "franka_panda", + "libero_panda", + "oxe_droid", + "single_panda_gripper" + ] + }, + "piper": { + "description": "AgileX Piper (6-DOF + gripper)", + "category": "arm", + "joints": 11, + "asset": { + "dir": "agilex_piper", + "model_xml": "piper.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "piper_mj_description" + }, + "aliases": [ + "agilex_piper" + ] + }, + "sawyer": { + "description": "Rethink Robotics Sawyer (7-DOF)", + "category": "arm", + "joints": 7, + "asset": { + "dir": "rethink_robotics_sawyer", + "model_xml": "sawyer.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "sawyer_mj_description" + }, + "aliases": [ + "rethink_sawyer" + ] + }, + "so100": { + "description": "TrossenRobotics SO-ARM100 (6-DOF, Feetech servos)", + "category": "arm", + "joints": 13, + "asset": { + "dir": "trs_so_arm100", + "model_xml": "so_arm100.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "so_arm100_mj_description" + }, + "hardware": { + "lerobot_type": "so100_follower" + }, + "legacy_urdf": "so100/so100.urdf", + "aliases": [ + "so100_4cam", + "so100_dualcam", + "so100_follower", + "so_arm100", + "trs_so_arm100" + ] + }, + "so101": { + "description": "RobotStudio SO-101 (6-DOF, upgraded SO-100)", + "category": "arm", + "joints": 9, + "asset": { + "dir": "robotstudio_so101", + "model_xml": "so101.xml", + "scene_xml": "scene_box.xml", + "robot_descriptions_module": "so_arm101_mj_description" + }, + "hardware": { + "lerobot_type": "so101_follower" + }, + "aliases": [ + "robotstudio_so101", + "so101_dualcam", + "so101_follower", + "so101_leader", + "so101_tricam" + ] + }, + "ur10e": { + "description": "Universal Robots UR10e (6-DOF industrial)", + "category": "arm", + "joints": 6, + "asset": { + "dir": "universal_robots_ur10e", + "model_xml": "ur10e.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "ur10e_mj_description" + } + }, + "ur5e": { + "description": "Universal Robots UR5e (6-DOF industrial)", + "category": "arm", + "joints": 8, + "asset": { + "dir": "universal_robots_ur5e", + "model_xml": "ur5e.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "ur5e_mj_description" + } + }, + "vx300s": { + "description": "Trossen ViperX 300s (6-DOF + gripper)", + "category": "arm", + "joints": 19, + "asset": { + "dir": "trossen_vx300s", + "model_xml": "vx300s.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "viper_mj_description" + }, + "aliases": [ + "oxe_widowx", + "trossen_vx300s", + "viper_x300s" + ] + }, + "wx250s": { + "description": "Trossen WidowX 250s (6-DOF + gripper)", + "category": "arm", + "joints": 16, + "asset": { + "dir": "trossen_wx250s", + "model_xml": "wx250s.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "widow_mj_description" + }, + "aliases": [ + "widowx_250s", + "trossen_wx250s" + ] + }, + "xarm7": { + "description": "UFactory xArm 7 (7-DOF + gripper)", + "category": "arm", + "joints": 13, + "asset": { + "dir": "ufactory_xarm7", + "model_xml": "xarm7.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "xarm7_mj_description" + }, + "aliases": [ + "ufactory_xarm7" + ] + }, + "yam": { + "description": "i2rt YAM Arm (8-DOF)", + "category": "arm", + "joints": 8, + "asset": { + "dir": "i2rt_yam", + "model_xml": "yam.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "yam_mj_description" + }, + "aliases": [ + "i2rt_yam" + ] + }, + "z1": { + "description": "Unitree Z1 (6-DOF + gripper)", + "category": "arm", + "joints": 8, + "asset": { + "dir": "unitree_z1", + "model_xml": "z1_gripper.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "z1_mj_description" + }, + "aliases": [ + "unitree_z1" + ] + }, + "aloha": { + "description": "ALOHA Bimanual (2x ViperX 300s, 14-DOF + 2 grippers)", + "category": "bimanual", + "joints": 28, + "asset": { + "dir": "aloha", + "model_xml": "aloha.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "aloha_mj_description" + }, + "hardware": { + "lerobot_type": "bi_so_follower" + }, + "aliases": [ + "agibot_dual_arm", + "agibot_dual_arm_dexhand", + "agibot_dual_arm_full", + "agibot_dual_arm_gripper", + "agibot_genie1", + "bi_so_follower", + "galaxea_r1_pro" + ] + }, + "bi_openarm": { + "description": "Bi-manual OpenArm (dual-arm coordination)", + "category": "bimanual", + "hardware": { + "lerobot_type": "bi_openarm" + }, + "aliases": [ + "bi_openarm_follower", + "dual_openarm", + "openarm_bimanual" + ] + }, + "trossen_wxai": { + "description": "Trossen WidowX AI Bimanual", + "category": "bimanual", + "joints": 17, + "asset": { + "dir": "trossen_wxai", + "model_xml": "trossen_ai_bimanual.xml", + "scene_xml": "scene.xml" + }, + "aliases": [ + "trossen_ai_bimanual" + ] + }, + "reachy_mini": { + "description": "Pollen Reachy Mini (6-DOF Stewart head + antennas, 9 actuators)", + "category": "expressive", + "joints": 21, + "asset": { + "dir": "reachy_mini", + "model_xml": "mjcf/reachy_mini.xml", + "scene_xml": "mjcf/scene.xml", + "source": { + "type": "github", + "repo": "pollen-robotics/reachy_mini", + "subdir": "src/reachy_mini/descriptions/reachy_mini" + } + }, + "aliases": [ + "pollen_reachy_mini", + "reachy", + "reachy-mini", + "reachymini" + ] + }, + "ability_hand": { + "description": "PSYONIC Ability Hand (5-finger prosthetic, 11-DOF)", + "category": "hand", + "joints": 11, + "asset": { + "dir": "mujoco_xml", + "model_xml": "scene.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "ability_hand_mj_description" + }, + "aliases": [ + "psyonic_ability_hand" + ] + }, + "aero_hand": { + "description": "Tetheria Aero Hand Open (16-DOF dexterous)", + "category": "hand", + "joints": 16, + "asset": { + "dir": "tetheria_aero_hand_open", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "aero_hand_open_mj_description" + }, + "aliases": [ + "tetheria_aero_hand", + "aero_hand_open" + ] + }, + "allegro_hand": { + "description": "Wonik Allegro Hand (16-DOF dexterous)", + "category": "hand", + "joints": 16, + "asset": { + "dir": "wonik_allegro", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "allegro_hand_mj_description" + }, + "aliases": [ + "wonik_allegro" + ] + }, + "leap_hand": { + "description": "LEAP Hand (16-DOF dexterous)", + "category": "hand", + "joints": 41, + "asset": { + "dir": "leap_hand", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "leap_hand_mj_description" + } + }, + "robotiq_2f85": { + "description": "Robotiq 2F-85 Gripper (2-finger adaptive)", + "category": "hand", + "joints": 16, + "asset": { + "dir": "robotiq_2f85", + "model_xml": "2f85.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "robotiq_2f85_mj_description" + }, + "aliases": [ + "robotiq" + ] + }, + "robotiq_2f85_v4": { + "description": "Robotiq 2F-85 v4 Gripper (updated model)", + "category": "hand", + "joints": 6, + "asset": { + "dir": "robotiq_2f85_v4", + "model_xml": "2f85.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "robotiq_2f85_v4_mj_description" + } + }, + "shadow_dexee": { + "description": "Shadow DexEE Dexterous End-Effector (12-DOF)", + "category": "hand", + "joints": 12, + "asset": { + "dir": "shadow_dexee", + "model_xml": "shadow_dexee.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "shadow_dexee_mj_description" + } + }, + "shadow_hand": { + "description": "Shadow Dexterous Hand (24-DOF)", + "category": "hand", + "joints": 45, + "asset": { + "dir": "shadow_hand", + "model_xml": "left_hand.xml", + "scene_xml": "scene_left.xml", + "robot_descriptions_module": "shadow_hand_mj_description" + } + }, + "adam_lite": { + "description": "PNDbotics Adam Lite Humanoid (26-DOF)", + "category": "humanoid", + "joints": 26, + "asset": { + "dir": "pndbotics_adam_lite", + "model_xml": "adam_lite.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "adam_lite_mj_description" + }, + "aliases": [ + "pndbotics_adam_lite" + ] + }, + "apollo": { + "description": "Apptronik Apollo Humanoid (34-DOF)", + "category": "humanoid", + "joints": 34, + "asset": { + "dir": "apptronik_apollo", + "model_xml": "apptronik_apollo.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "apollo_mj_description" + }, + "aliases": [ + "apptronik_apollo" + ] + }, + "asimov_v0": { + "description": "Asimov V0 Bipedal Legs (12-DOF + 2 passive toes)", + "category": "humanoid", + "joints": 15, + "asset": { + "dir": "asimov_v0", + "model_xml": "xmls/asimov.xml", + "scene_xml": "xmls/asimov.xml", + "source": { + "type": "github", + "repo": "asimovinc/asimov-v0", + "subdir": "sim-model" + } + }, + "aliases": [ + "asimov" + ] + }, + "booster_t1": { + "description": "Booster T1 Humanoid (24-DOF)", + "category": "humanoid", + "joints": 24, + "asset": { + "dir": "booster_t1", + "model_xml": "t1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "booster_t1_mj_description" + } + }, + "cassie": { + "description": "Agility Cassie Bipedal Robot", + "category": "humanoid", + "joints": 28, + "asset": { + "dir": "agility_cassie", + "model_xml": "cassie.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "cassie_mj_description" + }, + "aliases": [ + "agility_cassie" + ] + }, + "elf2": { + "description": "BXI Elf2 Humanoid (25-DOF)", + "category": "humanoid", + "joints": 26, + "asset": { + "dir": "xml", + "model_xml": "elf2_dof25.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "elf2_mj_description" + }, + "aliases": [ + "bxi_elf2" + ] + }, + "fourier_n1": { + "description": "Fourier N1 / GR-1 Humanoid (26-DOF)", + "category": "humanoid", + "joints": 26, + "asset": { + "dir": "fourier_n1", + "model_xml": "n1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "n1_mj_description" + }, + "aliases": [ + "fourier_gr1", + "fourier_gr1_arms_only", + "fourier_gr1_arms_waist", + "fourier_gr1_full_upper_body", + "gr1" + ] + }, + "jvrc": { + "description": "JVRC-1 Humanoid (HRP-based, 45-DOF)", + "category": "humanoid", + "joints": 45, + "asset": { + "dir": "jvrc_mj_description", + "model_xml": "xml/jvrc1.xml", + "scene_xml": "xml/jvrc1.xml", + "robot_descriptions_module": "jvrc_mj_description" + }, + "aliases": [ + "jvrc1" + ] + }, + "op3": { + "description": "ROBOTIS OP3 Humanoid (20-DOF)", + "category": "humanoid", + "joints": 21, + "asset": { + "dir": "robotis_op3", + "model_xml": "op3.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "op3_mj_description" + }, + "aliases": [ + "robotis_op3" + ] + }, + "open_duck_mini": { + "description": "Open Duck Mini V2 (16-DOF expressive biped, Feetech servos)", + "category": "humanoid", + "joints": 16, + "asset": { + "dir": "open_duck_mini_v2", + "model_xml": "open_duck_mini_v2.xml", + "scene_xml": "scene.xml", + "source": { + "type": "github", + "repo": "apirrone/Open_Duck_Mini", + "subdir": "mini_bdx/robots/open_duck_mini_v2" + } + }, + "aliases": [ + "bdx", + "mini_bdx", + "open_duck", + "open_duck_mini_v2", + "open_duck_v2" + ] + }, + "rby1": { + "description": "Rainbow Robotics RB-Y1A Mobile Manipulator (31-DOF)", + "category": "humanoid", + "joints": 31, + "asset": { + "dir": "mujoco", + "model_xml": "model.xml", + "scene_xml": "model.xml", + "robot_descriptions_module": "rby1_mj_description" + }, + "aliases": [ + "rby1a", + "rainbow_rby1" + ] + }, + "reachy2": { + "description": "Pollen Reachy 2", + "category": "humanoid", + "hardware": { + "lerobot_type": "reachy2" + } + }, + "talos": { + "description": "PAL Robotics TALOS Humanoid (32-DOF)", + "category": "humanoid", + "joints": 45, + "asset": { + "dir": "pal_talos", + "model_xml": "talos.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "talos_mj_description" + }, + "aliases": [ + "pal_talos" + ] + }, + "toddlerbot_2xc": { + "description": "Toddlerbot 2xC Humanoid (45-DOF)", + "category": "humanoid", + "joints": 45, + "asset": { + "dir": "toddlerbot_2xc", + "model_xml": "toddlerbot_2xc.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "toddlerbot_2xc_mj_description" + } + }, + "toddlerbot_2xm": { + "description": "Toddlerbot 2xM Humanoid (45-DOF)", + "category": "humanoid", + "joints": 45, + "asset": { + "dir": "toddlerbot_2xm", + "model_xml": "toddlerbot_2xm.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "toddlerbot_2xm_mj_description" + } + }, + "unitree_g1": { + "description": "Unitree G1 Humanoid (29-DOF + dexterous hands)", + "category": "humanoid", + "joints": 46, + "asset": { + "dir": "unitree_g1", + "model_xml": "g1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "g1_mj_description" + }, + "hardware": { + "lerobot_type": "unitree_g1" + }, + "legacy_urdf": "unitree_g1/g1.urdf", + "aliases": [ + "g1", + "g1_wbc", + "unitree_g1_full_body", + "unitree_g1_locomanip", + "unitree_g1_wbc" + ] + }, + "unitree_h1": { + "description": "Unitree H1 Humanoid (19-DOF)", + "category": "humanoid", + "joints": 20, + "asset": { + "dir": "unitree_h1", + "model_xml": "h1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "h1_mj_description" + }, + "aliases": [ + "h1" + ] + }, + "unitree_h1_2": { + "description": "Unitree H1-2 Humanoid (52-DOF, with hands)", + "category": "humanoid", + "joints": 52, + "asset": { + "dir": "h1_2_description", + "model_xml": "h1_2.xml", + "scene_xml": "h1_2.xml", + "robot_descriptions_module": "h1_2_mj_description" + }, + "aliases": [ + "h1_2" + ] + }, + "aliengo": { + "description": "Unitree Aliengo Quadruped (12-DOF)", + "category": "mobile", + "joints": 13, + "asset": { + "dir": "aliengo", + "model_xml": "xml/aliengo.xml", + "scene_xml": "xml/aliengo.xml", + "robot_descriptions_module": "aliengo_mj_description" + }, + "aliases": [ + "unitree_aliengo" + ] + }, + "anymal_b": { + "description": "ANYbotics ANYmal B Quadruped (12-DOF)", + "category": "mobile", + "joints": 13, + "asset": { + "dir": "anybotics_anymal_b", + "model_xml": "anymal_b.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "anymal_b_mj_description" + }, + "aliases": [ + "anybotics_anymal_b" + ] + }, + "anymal_c": { + "description": "ANYbotics ANYmal C Quadruped (12-DOF)", + "category": "mobile", + "joints": 13, + "asset": { + "dir": "anybotics_anymal_c", + "model_xml": "anymal_c.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "anymal_c_mj_description" + }, + "aliases": [ + "anybotics_anymal_c" + ] + }, + "earthrover": { + "description": "EarthRover Mini Plus (mobile outdoor navigation)", + "category": "mobile", + "hardware": { + "lerobot_type": "earthrover" + }, + "aliases": [ + "earth_rover", + "earthrover_mini_plus", + "frodobots" + ] + }, + "go1": { + "description": "Unitree Go1 Quadruped (12-DOF)", + "category": "mobile", + "joints": 13, + "asset": { + "dir": "unitree_go1", + "model_xml": "go1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "go1_mj_description" + }, + "aliases": [ + "unitree_go1" + ] + }, + "lekiwi": { + "description": "LeKiwi mobile robot", + "category": "mobile", + "hardware": { + "lerobot_type": "lekiwi" + } + }, + "robot_soccer_kit": { + "description": "Robot Soccer Kit (multi-robot soccer, 65-DOF total)", + "category": "mobile", + "joints": 65, + "asset": { + "dir": "robot_soccer_kit", + "model_xml": "robot_soccer_kit.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "rsk_mj_description" + }, + "aliases": [ + "rsk" + ] + }, + "spot": { + "description": "Boston Dynamics Spot (with arm)", + "category": "mobile", + "joints": 20, + "asset": { + "dir": "boston_dynamics_spot", + "model_xml": "spot_arm.xml", + "scene_xml": "scene_arm.xml", + "robot_descriptions_module": "spot_mj_description" + }, + "aliases": [ + "boston_dynamics_spot" + ] + }, + "stretch": { + "description": "Hello Robot Stretch (original, mobile manipulator)", + "category": "mobile", + "joints": 18, + "asset": { + "dir": "hello_robot_stretch", + "model_xml": "stretch.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "stretch_mj_description" + }, + "aliases": [ + "hello_robot_stretch_original" + ] + }, + "stretch3": { + "description": "Hello Robot Stretch 3 (mobile manipulator)", + "category": "mobile", + "joints": 41, + "asset": { + "dir": "hello_robot_stretch_3", + "model_xml": "stretch.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "stretch_3_mj_description" + }, + "aliases": [ + "hello_robot_stretch", + "hello_robot_stretch_3" + ] + }, + "tiago_dual": { + "description": "PAL Robotics TIAGo++ Dual-Arm Mobile (26-DOF)", + "category": "mobile", + "joints": 26, + "asset": { + "dir": "pal_tiago_dual", + "model_xml": "tiago_dual.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "tiago++_mj_description" + }, + "aliases": [ + "tiago++", + "pal_tiago_dual" + ] + }, + "unitree_a1": { + "description": "Unitree A1 Quadruped", + "category": "mobile", + "joints": 16, + "asset": { + "dir": "unitree_a1", + "model_xml": "a1.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "a1_mj_description" + }, + "aliases": [ + "a1" + ] + }, + "unitree_go2": { + "description": "Unitree Go2 Quadruped", + "category": "mobile", + "joints": 40, + "asset": { + "dir": "unitree_go2", + "model_xml": "go2.xml", + "scene_xml": "scene.xml", + "robot_descriptions_module": "go2_mj_description" + }, + "aliases": [ + "go2" + ] + }, + "google_robot": { + "description": "Google Robot (mobile base + arm, RT-X)", + "category": "mobile_manip", + "joints": 10, + "asset": { + "dir": "google_robot", + "model_xml": "robot.xml", + "scene_xml": "scene.xml" + }, + "aliases": [ + "oxe_google" + ] } - }, - "aliases": [ - "bdx", - "mini_bdx", - "open_duck", - "open_duck_mini_v2", - "open_duck_v2" - ] - }, - "asimov_v0": { - "description": "Asimov V0 Bipedal Legs (12-DOF + 2 passive toes)", - "category": "humanoid", - "joints": 15, - "asset": { - "dir": "asimov_v0", - "model_xml": "asimov_v0.xml", - "scene_xml": "scene.xml", - "source": { - "type": "github", - "repo": "asimovinc/asimov-v0", - "subdir": "sim-model" - } - }, - "aliases": [ - "asimov" - ] - }, - "reachy_mini": { - "description": "Pollen Reachy Mini (6-DOF Stewart head + antennas, 9 actuators)", - "category": "expressive", - "joints": 21, - "asset": { - "dir": "reachy_mini", - "model_xml": "mjcf/reachy_mini.xml", - "scene_xml": "mjcf/scene.xml", - "source": { - "type": "github", - "repo": "pollen-robotics/reachy_mini", - "subdir": "src/reachy_mini/descriptions/reachy_mini" - } - }, - "aliases": [ - "pollen_reachy_mini", - "reachy", - "reachy-mini", - "reachymini" - ] - }, - "unitree_go2": { - "description": "Unitree Go2 Quadruped", - "category": "mobile", - "joints": 40, - "asset": { - "dir": "unitree_go2", - "model_xml": "go2.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "go2_mj_description" - }, - "aliases": [ - "go2" - ] - }, - "unitree_a1": { - "description": "Unitree A1 Quadruped", - "category": "mobile", - "joints": 16, - "asset": { - "dir": "unitree_a1", - "model_xml": "a1.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "a1_mj_description" - }, - "aliases": [ - "a1" - ] - }, - "spot": { - "description": "Boston Dynamics Spot (with arm)", - "category": "mobile", - "joints": 20, - "asset": { - "dir": "boston_dynamics_spot", - "model_xml": "spot_arm.xml", - "scene_xml": "scene_arm.xml", - "robot_descriptions_module": "spot_mj_description" - }, - "aliases": [ - "boston_dynamics_spot" - ] - }, - "stretch3": { - "description": "Hello Robot Stretch 3 (mobile manipulator)", - "category": "mobile", - "joints": 41, - "asset": { - "dir": "hello_robot_stretch_3", - "model_xml": "stretch.xml", - "scene_xml": "scene.xml", - "robot_descriptions_module": "stretch_3_mj_description" - }, - "aliases": [ - "hello_robot_stretch", - "hello_robot_stretch_3" - ] - }, - "google_robot": { - "description": "Google Robot (mobile base + arm, RT-X)", - "category": "mobile_manip", - "joints": 10, - "asset": { - "dir": "google_robot", - "model_xml": "robot.xml", - "scene_xml": "scene.xml" - }, - "aliases": [ - "oxe_google" - ] - }, - "lekiwi": { - "description": "LeKiwi mobile robot", - "category": "mobile", - "hardware": { - "lerobot_type": "lekiwi" - } - }, - "reachy2": { - "description": "Pollen Reachy 2", - "category": "humanoid", - "hardware": { - "lerobot_type": "reachy2" - } - }, - "hope_jr": { - "description": "Hope Junior arm", - "category": "arm", - "hardware": { - "lerobot_type": "hope_jr" - } - }, - "earthrover": { - "description": "EarthRover Mini Plus (mobile outdoor navigation)", - "category": "mobile", - "hardware": { - "lerobot_type": "earthrover" - }, - "aliases": [ - "earth_rover", - "earthrover_mini_plus", - "frodobots" - ] - }, - "omx": { - "description": "OMX Robot Arm (ROBOTIS, CAN bus motors)", - "category": "arm", - "hardware": { - "lerobot_type": "omx" - }, - "aliases": [ - "omx_follower", - "omx_robot", - "robotis_omx" - ] - }, - "bi_openarm": { - "description": "Bi-manual OpenArm (dual-arm coordination)", - "category": "bimanual", - "hardware": { - "lerobot_type": "bi_openarm" - }, - "aliases": [ - "bi_openarm_follower", - "dual_openarm", - "openarm_bimanual" - ] } - } } diff --git a/strands_robots/registry/user_registry.py b/strands_robots/registry/user_registry.py new file mode 100644 index 0000000..bce87c0 --- /dev/null +++ b/strands_robots/registry/user_registry.py @@ -0,0 +1,280 @@ +"""User-local robot registry — runtime registration without editing package JSON. + +Provides ``register_robot()`` and ``unregister_robot()`` for adding custom +robots that persist across sessions via a ``user_robots.json`` file stored +alongside the asset cache. + +File location (in priority order): + 1. ``$STRANDS_ASSETS_DIR/user_robots.json`` + 2. ``~/.strands_robots/user_robots.json`` + +At load time the user overlay is merged *on top of* the package +``robots.json`` — user entries win on name collision, so you can also +override built-in robots locally. + +Usage:: + + from strands_robots.registry import register_robot, unregister_robot + + # Register a custom robot with MJCF + register_robot( + name="my_arm", + model_xml="my_arm.xml", + description="My custom 6-DOF arm", + category="arm", + joints=6, + asset_dir="my_arm", # resolved relative to assets dir + ) + + # Now works everywhere: + from strands_robots.simulation import create_simulation + sim = create_simulation() + sim.create_world() + sim.add_robot("my_arm") # ✅ auto-resolved + + # Remove it + unregister_robot("my_arm") +""" + +import json +import logging +from pathlib import Path +from typing import Any + +from strands_robots.utils import get_base_dir, resolve_asset_path + +from .loader import invalidate_cache + +logger = logging.getLogger(__name__) + + +def _get_user_registry_path() -> Path: + """Get path to the user-local robot registry file.""" + return get_base_dir() / "user_robots.json" + + +def _load_user_registry() -> dict[str, Any]: + """Load the user-local robot registry file. + + Returns: + Dict with ``"robots"`` key mapping names to robot definitions. + """ + path = _get_user_registry_path() + if not path.exists(): + return {"robots": {}} + try: + with open(path, encoding="utf-8") as f: + data = json.load(f) + if "robots" not in data: + data = {"robots": {}} + return data + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Failed to load user registry %s: %s", path, exc) + return {"robots": {}} + + +def _save_user_registry(data: dict[str, Any]) -> None: + """Save the user-local robot registry file.""" + path = _get_user_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=4) + f.write("\n") + logger.info("Saved user registry: %s (%d robots)", path, len(data.get("robots", {}))) + + +def get_user_robots() -> dict[str, Any]: + """Get all user-registered robots. + + Returns: + Dict mapping robot names to their definitions. + """ + return _load_user_registry().get("robots", {}) + + +def register_robot( + name: str, + *, + model_xml: str, + description: str = "", + category: str = "arm", + joints: int = 0, + asset_dir: str | None = None, + scene_xml: str | None = None, + aliases: list[str] | None = None, + robot_descriptions_module: str | None = None, + hardware: dict[str, Any] | None = None, + overwrite: bool = False, +) -> dict[str, Any]: + """Register a custom robot in the user-local registry. + + .. warning:: Security + + This function is a **library-only** API and must NOT be exposed + as an agent @tool without additional safeguards. A malicious + agent could register a robot pointing to attacker-controlled MJCF + that executes code via MuJoCo plugins. If tool exposure is needed + in the future, gate it behind STRANDS_TRUST_REMOTE_CODE and + validate all paths with _safe_join. + + The robot becomes immediately available in ``get_robot()``, + ``list_robots()``, ``resolve_model_path()``, ``sim.add_robot()``, etc. + + Args: + name: Canonical robot name (lowercase, underscores). + model_xml: Path to MJCF/URDF model file, relative to ``asset_dir``. + description: Human-readable description. + category: Robot category (arm, humanoid, mobile, hand, aerial, bimanual, ...). + joints: Number of actuated joints. + asset_dir: Directory containing the model file and meshes. + - Absolute path: used as-is (``~/`` expanded). + - Relative path: resolved against the assets directory + (``STRANDS_ASSETS_DIR`` or ``~/.strands_robots/assets/``). + - None: defaults to ``//``. + scene_xml: Scene XML (with ground/lights). Defaults to ``model_xml``. + aliases: Alternative names for this robot. + robot_descriptions_module: Optional ``robot_descriptions`` module name. + hardware: Optional hardware config dict (``lerobot_type``, etc.). + overwrite: If False (default), raises ValueError if robot already exists. + + Returns: + The registered robot definition dict. + + Raises: + ValueError: If name already exists and ``overwrite`` is False. + FileNotFoundError: If ``model_xml`` doesn't exist at the resolved path. + + Example:: + + register_robot( + name="my_arm", + model_xml="my_arm.xml", + asset_dir="~/robots/my_arm_v2", + description="Custom 6-DOF arm with gripper", + category="arm", + joints=7, + aliases=["myarm", "custom_arm"], + ) + """ + # Normalize name + name = name.lower().strip().replace("-", "_") + + # Load existing + data = _load_user_registry() + + # Check for existing (in user registry AND package registry) + if not overwrite: + if name in data.get("robots", {}): + raise ValueError(f"Robot '{name}' already in user registry. Use overwrite=True to replace.") + # Also check package registry + try: + from .robots import get_robot as _pkg_get_robot + + if _pkg_get_robot(name) is not None: + logger.info( + "Robot '%s' exists in package registry — user registration will override it.", + name, + ) + except ImportError: + pass + + # Resolve asset_dir via shared utility (respects STRANDS_ASSETS_DIR) + resolved_dir = resolve_asset_path(asset_dir, default_name=name) + + # Use the directory name as the asset "dir" key (relative to search paths) + # This matches how resolve_model_path works: search_dir / asset["dir"] / xml + dir_name = resolved_dir.name + + # Validate model_xml exists (if asset_dir exists) + model_path = resolved_dir / model_xml + if resolved_dir.exists() and not model_path.exists(): + raise FileNotFoundError(f"Model XML not found: {model_path}\nEnsure '{model_xml}' exists in '{resolved_dir}'") + + # Build entry + entry: dict[str, Any] = { + "description": description, + "category": category, + "joints": joints, + "asset": { + "dir": dir_name, + "model_xml": model_xml, + "scene_xml": scene_xml or model_xml, + }, + } + + if robot_descriptions_module: + entry["asset"]["robot_descriptions_module"] = robot_descriptions_module + + if aliases: + entry["aliases"] = aliases + + if hardware: + entry["hardware"] = hardware + + # Store the full resolved path so the asset manager can find it + # even if the dir isn't in the standard search paths + entry["_user_asset_path"] = str(resolved_dir) + + # Save + data.setdefault("robots", {})[name] = entry + _save_user_registry(data) + + # Invalidate loader cache so next get_robot() picks up the merge + _invalidate_cache() + + logger.info("Registered robot '%s' → %s/%s", name, dir_name, model_xml) + return entry + + +def unregister_robot(name: str) -> bool: + """Remove a robot from the user-local registry. + + Does not affect the package ``robots.json``. If the robot exists + only in the package registry, this is a no-op. + + Args: + name: Robot name to remove. + + Returns: + True if the robot was removed, False if it wasn't in the user registry. + """ + name = name.lower().strip().replace("-", "_") + data = _load_user_registry() + + if name not in data.get("robots", {}): + logger.info("Robot '%s' not in user registry — nothing to remove.", name) + return False + + del data["robots"][name] + _save_user_registry(data) + _invalidate_cache() + + logger.info("Unregistered robot '%s'", name) + return True + + +def list_user_robots() -> list[dict[str, Any]]: + """List all user-registered robots. + + Returns: + List of dicts with name, description, category, path info. + """ + robots = get_user_robots() + result = [] + for name, info in sorted(robots.items()): + result.append( + { + "name": name, + "description": info.get("description", ""), + "category": info.get("category", ""), + "joints": info.get("joints", 0), + "asset_dir": info.get("_user_asset_path", ""), + "model_xml": info.get("asset", {}).get("model_xml", ""), + } + ) + return result + + +def _invalidate_cache() -> None: + """Invalidate the loader cache so merged data is reloaded.""" + invalidate_cache("robots") diff --git a/strands_robots/simulation/__init__.py b/strands_robots/simulation/__init__.py new file mode 100644 index 0000000..d9674a9 --- /dev/null +++ b/strands_robots/simulation/__init__.py @@ -0,0 +1,111 @@ +"""Strands Robots Simulation — multi-backend simulation framework. + +Architecture:: + + simulation/ + ├── __init__.py ← this file (re-exports, lazy loading) + ├── base.py ← SimEngine ABC + ├── factory.py ← create_simulation() + backend registration + ├── models.py ← shared dataclasses (SimWorld, SimRobot, ...) + └── model_registry.py ← URDF/MJCF resolution (shared across backends) + + # MuJoCo backend added in subsequent PRs. + +Usage:: + + # Default (MuJoCo) via factory + from strands_robots.simulation import create_simulation + sim = create_simulation() + + # Direct class access + from strands_robots.simulation import Simulation + sim = Simulation() + + # Explicit backend + from strands_robots.simulation.mujoco import MuJoCoSimulation + + # Shared types (no heavy deps) + from strands_robots.simulation import SimWorld, SimRobot, SimObject + + # ABC for custom backends + from strands_robots.simulation.base import SimEngine + +Future backends:: + + from strands_robots.simulation.isaac import IsaacSimulation + from strands_robots.simulation.newton import NewtonSimulation +""" + +import importlib as _importlib +from typing import Any + +# --- Light imports (no heavy deps — stdlib + dataclasses only) --- +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.factory import ( + create_simulation, + list_backends, + register_backend, +) +from strands_robots.simulation.model_registry import ( + list_available_models, + list_registered_urdfs, + register_urdf, + resolve_model, + resolve_urdf, +) +from strands_robots.simulation.models import ( + SimCamera, + SimObject, + SimRobot, + SimStatus, + SimWorld, + TrajectoryStep, +) + +# --- Heavy imports (lazy — loaded when mujoco backend is available) --- +# MuJoCo-specific lazy imports will be added when the mujoco/ subpackage +# is introduced. For now, only the lightweight foundation is available. +_LAZY_IMPORTS: dict[str, tuple[str, str]] = {} + + +__all__ = [ + # ABC + "SimEngine", + # Factory + "create_simulation", + "list_backends", + "register_backend", + # Default backend alias (available when mujoco backend is installed) + # "Simulation", + # "MuJoCoSimulation", + # Shared dataclasses + "SimStatus", + "SimRobot", + "SimObject", + "SimCamera", + "SimWorld", + "TrajectoryStep", + # MuJoCo builder (available when mujoco backend is installed) + # "MJCFBuilder", + # Model registry + "register_urdf", + "resolve_model", + "resolve_urdf", + "list_registered_urdfs", + "list_available_models", +] + + +def __getattr__(name: str) -> Any: + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + module = _importlib.import_module(module_path) + value = getattr(module, attr_name) + globals()[name] = value + return value + raise AttributeError(f"module 'strands_robots.simulation' has no attribute {name!r}") + + +# NOTE: MuJoCo GL backend configuration lives in the top-level +# strands_robots/__init__.py to ensure it runs before any `import mujoco`. +# Do NOT duplicate it here — see PR #86 for the canonical location. diff --git a/strands_robots/simulation/base.py b/strands_robots/simulation/base.py new file mode 100644 index 0000000..7ca2098 --- /dev/null +++ b/strands_robots/simulation/base.py @@ -0,0 +1,222 @@ +"""Simulation ABC — backend-agnostic interface for all simulation engines. + +Every simulation backend (MuJoCo, Isaac, Newton) implements this interface. +Agent tools and the Robot() factory interact through these methods only — +they never touch backend-specific APIs directly. + +Usage:: + + from strands_robots.simulation import Simulation # returns MuJoCo by default + + # Or explicitly: + from strands_robots.simulation.mujoco import MuJoCoSimulation + + # Future: + from strands_robots.simulation.isaac import IsaacSimulation + from strands_robots.simulation.newton import NewtonSimulation +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger(__name__) + + +class SimEngine(ABC): + """Abstract base class for simulation engines. + + Defines the contract that all backends (MuJoCo, Isaac, Newton) must + implement. This is the *programmatic* API — the AgentTool layer + wraps it with tool_spec/stream for LLM access. + + Method categories: + + **Required** (``@abstractmethod``): Core simulation loop — world + lifecycle, entity management, observation/action, rendering. Every + physics engine must implement these to be usable. + + **Optional** (default raises ``NotImplementedError``): Higher-level + features — scene loading, policy running, domain randomization, + contact queries. Backends opt in by overriding only what they support. + + Lifecycle:: + + sim = SomeEngine() + sim.create_world() + sim.add_robot("so100", data_config="so100") + sim.add_object("cube", shape="box", position=[0.3, 0, 0.05]) + + # Control loop + obs = sim.get_observation("so100") + sim.send_action({"joint_0": 0.5}, robot_name="so100") + sim.step(n_steps=10) + + # Render + result = sim.render(camera_name="default") + + # Cleanup + sim.destroy() + """ + + # --- World lifecycle --- + + @abstractmethod + def create_world( + self, + timestep: float | None = None, + gravity: list[float] | None = None, + ground_plane: bool = True, + ) -> dict[str, Any]: + """Create a new simulation world.""" + ... + + @abstractmethod + def destroy(self) -> dict[str, Any]: + """Destroy the simulation world and release resources.""" + ... + + @abstractmethod + def reset(self) -> dict[str, Any]: + """Reset simulation to initial state.""" + ... + + @abstractmethod + def step(self, n_steps: int = 1) -> dict[str, Any]: + """Advance simulation by n physics steps.""" + ... + + @abstractmethod + def get_state(self) -> dict[str, Any]: + """Get full simulation state summary.""" + ... + + # --- Robot management --- + + @abstractmethod + def add_robot( + self, + name: str, + urdf_path: str | None = None, + data_config: str | None = None, + position: list[float] | None = None, + orientation: list[float] | None = None, + ) -> dict[str, Any]: + """Add a robot to the simulation.""" + ... + + @abstractmethod + def remove_robot(self, name: str) -> dict[str, Any]: + """Remove a robot from the simulation.""" + ... + + # --- Object management --- + + @abstractmethod + def add_object( + self, + name: str, + shape: str = "box", + position: list[float] | None = None, + orientation: list[float] | None = None, + size: list[float] | None = None, + color: list[float] | None = None, + mass: float = 0.1, + is_static: bool = False, + mesh_path: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Add an object to the scene.""" + ... + + @abstractmethod + def remove_object(self, name: str) -> dict[str, Any]: + """Remove an object from the scene.""" + ... + + # --- Observation / Action --- + + @abstractmethod + def get_observation(self, robot_name: str | None = None, camera_name: str | None = None) -> dict[str, Any]: + """Get observation from simulation. + + Convenience method that delegates to the underlying Robot + abstraction. Provides a unified interface for agent tools + that interact with simulation without needing to distinguish + between Robot and Sim layers. + """ + ... + + @abstractmethod + def send_action(self, action: dict[str, Any], robot_name: str | None = None, n_substeps: int = 1) -> None: + """Apply action to simulation. + + Convenience method that delegates to the underlying Robot + abstraction. The simulation engine acts as a facade so agent + tools can use ``sim.send_action()`` without knowing about + the Robot/Policy layer. + """ + ... + + # --- Rendering --- + + @abstractmethod + def render( + self, camera_name: str = "default", width: int | None = None, height: int | None = None + ) -> dict[str, Any]: + """Render a camera view. + + Returns dict with ``"image"`` key (numpy array, RGB uint8) and + optional ``"depth"`` key (float32 depth map). Resolution comes + from camera config unless ``width``/``height`` are given. + """ + ... + + # --- Optional overrides (have default no-op implementations) --- + + def load_scene(self, scene_path: str) -> dict[str, Any]: + """Load a complete scene from file. Override per backend.""" + raise NotImplementedError("load_scene not implemented by this backend") + + def run_policy(self, robot_name: str, policy_provider: str = "mock", **kwargs: Any) -> dict[str, Any]: + """Run a policy loop in the simulation. + + Orchestration shortcut: internally creates a Policy, then loops + ``obs → policy(obs) → send_action(action) → step()``. + Intentionally placed on SimEngine as a facade for agent tools + that need a single ``simulation(action="run_policy")`` interface. + Override per backend. + """ + raise NotImplementedError("run_policy not implemented by this backend") + + def randomize(self, **kwargs: Any) -> dict[str, Any]: + """Apply domain randomization. + + Concrete backends define their own parameter signatures. + Override per backend. + """ + raise NotImplementedError("randomize not implemented by this backend") + + def get_contacts(self) -> dict[str, Any]: + """Get contact information. Override per backend.""" + raise NotImplementedError("get_contacts not implemented by this backend") + + def cleanup(self) -> None: + """Release all resources. Called on __del__ / context exit.""" + pass + + def __enter__(self) -> SimEngine: + return self + + def __exit__(self, *exc: object) -> None: + self.cleanup() + + def __del__(self) -> None: + try: + self.cleanup() + except Exception as e: + # Best-effort cleanup during GC — exceptions can't propagate + # from __del__ (CPython ignores them), so log for visibility. + logger.warning("Cleanup error during __del__: %s", e) diff --git a/strands_robots/simulation/factory.py b/strands_robots/simulation/factory.py new file mode 100644 index 0000000..6785d67 --- /dev/null +++ b/strands_robots/simulation/factory.py @@ -0,0 +1,216 @@ +"""Simulation factory — create_simulation() and runtime backend registration. + +Mirrors the policy factory pattern: JSON-driven defaults with runtime +override capability. Backends are lazy-loaded on first use. + +Usage:: + + from strands_robots.simulation import create_simulation + + # Default backend (MuJoCo) + sim = create_simulation() + + # Explicit backend + sim = create_simulation("mujoco", timestep=0.001) + + # GPU-native Newton backend + sim = create_simulation("newton", num_envs=4096, solver="mujoco") + + # Custom backend (runtime-registered) + from strands_robots.simulation.factory import register_backend + register_backend("my_sim", lambda: MySimBackend, aliases=["custom"]) + sim = create_simulation("custom") +""" + +from __future__ import annotations + +import importlib +import logging +from collections.abc import Callable +from typing import Any + +from strands_robots.simulation.base import SimEngine + +logger = logging.getLogger(__name__) + +# ───────────────────────────────────────────────────────────────────── +# Built-in backend registry (lazy loaders — no imports at module load) +# ───────────────────────────────────────────────────────────────────── + +_BUILTIN_BACKENDS: dict[str, tuple[str, str]] = { + "mujoco": ( + "strands_robots.simulation.mujoco.simulation", + "Simulation", + ), + "newton": ( + "strands_robots.simulation.newton.simulation", + "NewtonSimulation", + ), + # Future: + # "isaac": ("strands_robots.simulation.isaac.simulation", "IsaacSimulation"), +} + +_BUILTIN_ALIASES: dict[str, str] = { + "mj": "mujoco", + "mjc": "mujoco", + "mjx": "mujoco", + "warp": "newton", + # "isaac_sim": "isaac", + # "isaacsim": "isaac", + # "nvidia": "isaac", +} + +DEFAULT_BACKEND = "mujoco" + +# ───────────────────────────────────────────────────────────────────── +# Runtime registration (for user-defined backends not in built-ins) +# ───────────────────────────────────────────────────────────────────── + +_runtime_registry: dict[str, Callable[[], type[SimEngine]]] = {} +_runtime_aliases: dict[str, str] = {} + + +def register_backend( + name: str, + loader: Callable[[], type[SimEngine]], + aliases: list[str] | None = None, + force: bool = False, +) -> None: + """Register a custom simulation backend at runtime. + + Use this to add backends without editing source code. + + Args: + name: Backend identifier (e.g., ``"my_physics"``). + loader: Zero-arg callable that returns the backend **class** + (not instance). Called lazily on first ``create_simulation()``. + aliases: Optional short names that resolve to ``name``. + force: If False (default), raises ValueError when ``name`` or + an alias is already registered. Set True to overwrite. + + Raises: + ValueError: If ``name`` or an alias conflicts with an existing + registration and ``force`` is False. + + Example:: + + from strands_robots.simulation.factory import register_backend + + register_backend( + "bullet", + lambda: BulletSimulation, + aliases=["pybullet", "pb"], + ) + sim = create_simulation("bullet") + """ + if not force: + if name in _runtime_registry or name in _BUILTIN_BACKENDS: + raise ValueError(f"Backend {name!r} already registered. Use force=True to overwrite.") + if aliases: + for alias in aliases: + if alias in _BUILTIN_ALIASES: + raise ValueError(f"Alias {alias!r} conflicts with built-in alias. Use force=True to overwrite.") + if alias in _runtime_aliases: + raise ValueError(f"Alias {alias!r} already registered. Use force=True to overwrite.") + + _runtime_registry[name] = loader + if aliases: + for alias in aliases: + _runtime_aliases[alias] = name + logger.debug("Registered simulation backend: %s (aliases=%s)", name, aliases) + + +def list_backends() -> list[str]: + """List all available backend names (built-in + runtime-registered). + + Returns: + Sorted list of unique backend identifiers and aliases. + + Example:: + + >>> list_backends() + ['mj', 'mjc', 'mjx', 'mujoco', 'newton', 'warp'] + """ + names: set[str] = set() + names.update(_BUILTIN_BACKENDS.keys()) + names.update(_BUILTIN_ALIASES.keys()) + names.update(_runtime_registry.keys()) + names.update(_runtime_aliases.keys()) + return sorted(names) + + +def _resolve_name(backend: str) -> str: + """Resolve aliases to canonical backend name.""" + # Runtime aliases first (user overrides win) + if backend in _runtime_aliases: + return _runtime_aliases[backend] + # Built-in aliases + if backend in _BUILTIN_ALIASES: + return _BUILTIN_ALIASES[backend] + return backend + + +def _import_backend_class(name: str) -> type[SimEngine]: + """Import and return a backend class by canonical name.""" + # 1. Runtime registry (user-registered) + if name in _runtime_registry: + cls: type[SimEngine] = _runtime_registry[name]() + logger.debug("Loaded runtime backend: %s → %s", name, cls.__name__) + return cls + + # 2. Built-in registry + if name in _BUILTIN_BACKENDS: + module_path, class_name = _BUILTIN_BACKENDS[name] + module = importlib.import_module(module_path) + backend_cls: type[SimEngine] = getattr(module, class_name) # type: ignore[assignment] + logger.debug("Loaded built-in backend: %s → %s.%s", name, module_path, class_name) + return backend_cls + + raise ValueError(f"Unknown simulation backend: {name!r}. Available: {', '.join(list_backends())}") + + +def create_simulation( + backend: str = DEFAULT_BACKEND, + **kwargs: Any, +) -> SimEngine: + """Create a simulation backend instance. + + This is the primary entry point for creating simulations. + Backend classes are lazy-loaded on first call. + + Args: + backend: Backend name or alias. Defaults to ``"mujoco"``. + Built-in: ``"mujoco"`` (aliases: ``"mj"``, ``"mjc"``, ``"mjx"``), + ``"newton"`` (alias: ``"warp"``). + **kwargs: Backend-specific keyword arguments passed to the + constructor (e.g., ``num_envs``, ``solver``). + + Returns: + A ``SimEngine`` instance ready for ``create_world()``. + + Raises: + ValueError: If the backend name is not recognized. + ImportError: If the backend's dependencies are missing + (e.g., ``pip install mujoco``). + + Examples:: + + # Default (MuJoCo) + sim = create_simulation() + sim.create_world() + sim.add_robot("so100") + + # Newton GPU backend + sim = create_simulation("newton", num_envs=4096) + + # With alias + sim = create_simulation("warp") + + # Pass kwargs to backend constructor + sim = create_simulation("mujoco", tool_name="my_sim") + """ + canonical = _resolve_name(backend) + logger.info("Creating simulation: %s (resolved from %r)", canonical, backend) + + BackendClass = _import_backend_class(canonical) + return BackendClass(**kwargs) diff --git a/strands_robots/simulation/model_registry.py b/strands_robots/simulation/model_registry.py new file mode 100644 index 0000000..a52028e --- /dev/null +++ b/strands_robots/simulation/model_registry.py @@ -0,0 +1,136 @@ +"""Robot model resolution — URDF registry + asset manager. + +Bridges the robot registry with actual URDF/MJCF files on disk. + +Resolution order for :func:`resolve_model`: + 1. User-registered URDFs (:func:`register_urdf`) + 2. URDF search paths (``STRANDS_ASSETS_DIR``, CWD, etc.) + 3. Asset manager (``robot_descriptions`` — fallback for standard robots) +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +from strands_robots.utils import get_assets_dir + +logger = logging.getLogger(__name__) + +# Default URDF search paths (checked in order). +# +# Resolution order for user-registered URDF lookups: +# 1. STRANDS_ASSETS_DIR (if set) — user override (via utils.get_assets_dir) +# 2. CWD/assets/ — project-local assets +_URDF_SEARCH_PATHS = [ + get_assets_dir(), + Path.cwd() / "assets", +] + +try: + from strands_robots.assets import ( + format_robot_table, + resolve_model_path, + ) + + _HAS_ASSET_MANAGER = True +except ImportError: + _HAS_ASSET_MANAGER = False + +try: + from strands_robots.registry import get_robot, resolve_name + + _HAS_REGISTRY = True +except ImportError: + _HAS_REGISTRY = False + +logger.info("Asset manager available: %s", _HAS_ASSET_MANAGER) + +# Runtime cache for user-registered URDFs +_URDF_REGISTRY: dict[str, str] = {} + + +def register_urdf(data_config: str, urdf_path: str) -> None: + """Register a URDF/MJCF file for a data_config name.""" + _URDF_REGISTRY[data_config] = urdf_path + logger.info("📋 Registered model for '%s': %s", data_config, urdf_path) + + +def resolve_model(name: str, prefer_scene: bool = True) -> str | None: + """Resolve a robot name or data_config to an MJCF/URDF model path. + + Resolution order (local assets take priority): + 1. User-registered URDFs (custom user registrations) + 2. URDF search paths (STRANDS_ASSETS_DIR, CWD, etc.) + 3. Asset manager (robot_descriptions — fallback for standard robots) + """ + # 1+2. Check local/custom paths first (user overrides win) + local = resolve_urdf(name) + if local: + return local + + # 3. Fall back to asset manager + if _HAS_ASSET_MANAGER: + path = resolve_model_path(name, prefer_scene=prefer_scene) + if path and path.exists(): + return str(path) + if prefer_scene: + path = resolve_model_path(name, prefer_scene=False) + if path and path.exists(): + return str(path) + + return None + + +def resolve_urdf(data_config: str) -> str | None: + """Resolve a data_config name to a URDF file path. + + Also checks the registry's ``legacy_urdf`` field — a backward-compatible + path for robots that were registered before the MJCF asset system + was introduced (e.g. robots originally configured with raw URDF paths). + """ + if data_config in _URDF_REGISTRY: + urdf_rel = _URDF_REGISTRY[data_config] + if os.path.isabs(urdf_rel) and os.path.exists(urdf_rel): + return str(urdf_rel) + for search_dir in _URDF_SEARCH_PATHS: + candidate = search_dir / urdf_rel + if candidate.exists(): + return str(candidate) + + if _HAS_REGISTRY: + canonical = resolve_name(data_config) + info = get_robot(canonical) + # ``legacy_urdf``: backward-compatible URDF path from before the + # MJCF asset system was introduced. Kept so that existing + # user configs referencing raw URDF paths continue to work. + if info and "legacy_urdf" in info: + urdf_rel = info["legacy_urdf"] + if os.path.isabs(urdf_rel) and os.path.exists(urdf_rel): + return str(urdf_rel) + for search_dir in _URDF_SEARCH_PATHS: + candidate = search_dir / urdf_rel + if candidate.exists(): + return str(candidate) + + logger.debug("URDF not found for '%s' in search paths", data_config) + return None + + +def list_registered_urdfs() -> dict[str, str | None]: + """List all registered URDF mappings and their resolved paths.""" + return {config_name: resolve_urdf(config_name) for config_name in _URDF_REGISTRY} + + +def list_available_models() -> str: + """List all available robot models (Menagerie + custom).""" + if _HAS_ASSET_MANAGER: + return str(format_robot_table()) + + lines = ["Registered URDFs:"] + for name, path in _URDF_REGISTRY.items(): + resolved = resolve_urdf(name) + status = "✅" if resolved else "❌" + lines.append(f" {status} {name}: {path}") + return "\n".join(lines) diff --git a/strands_robots/simulation/models.py b/strands_robots/simulation/models.py new file mode 100644 index 0000000..c945b01 --- /dev/null +++ b/strands_robots/simulation/models.py @@ -0,0 +1,132 @@ +"""Dataclasses for simulation state. + +These dataclasses provide a backend-independent typed state representation +consumed by simulation engine implementations (e.g. MuJoCo, Isaac Sim, +PyBullet). + +They enable: + - Type-safe state tracking across simulation steps. + - Serialisation for checkpoints and trajectory recording. + - A backend-independent interface for agent tools. + +They are defined alongside the ``SimEngine`` ABC because its method +signatures reference them (e.g. ``create_world() → SimWorld``). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class SimStatus(Enum): + """Simulation execution status.""" + + IDLE = "idle" + RUNNING = "running" + PAUSED = "paused" + COMPLETED = "completed" + ERROR = "error" + + +@dataclass +class SimRobot: + """A robot instance within the simulation.""" + + name: str + urdf_path: str + position: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) + orientation: list[float] = field(default_factory=lambda: [1.0, 0.0, 0.0, 0.0]) # wxyz quat + data_config: str | None = None + body_id: int = -1 + joint_ids: list[int] = field(default_factory=list) + joint_names: list[str] = field(default_factory=list) + actuator_ids: list[int] = field(default_factory=list) + namespace: str = "" + policy_running: bool = False + policy_steps: int = 0 + policy_instruction: str = "" + + +@dataclass +class SimObject: + """An object in the simulation scene.""" + + name: str + shape: str # "box", "sphere", "cylinder", "capsule", "mesh" + position: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) + orientation: list[float] = field(default_factory=lambda: [1.0, 0.0, 0.0, 0.0]) + size: list[float] = field(default_factory=lambda: [0.05, 0.05, 0.05]) + color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5, 1.0]) # RGBA + mass: float = 0.1 + mesh_path: str | None = None + body_id: int = -1 + is_static: bool = False + _original_position: list[float] = field(default_factory=list) + _original_color: list[float] = field(default_factory=list) + + def __post_init__(self) -> None: + self._original_position = list(self.position) + self._original_color = list(self.color) + + +@dataclass +class SimCamera: + """A camera in the simulation.""" + + name: str + position: list[float] = field(default_factory=lambda: [1.0, 1.0, 1.0]) + target: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.0]) + fov: float = 60.0 + width: int = 640 + height: int = 480 + camera_id: int = -1 + + +@dataclass +class TrajectoryStep: + """A single step in a recorded trajectory.""" + + timestamp: float + sim_time: float + robot_name: str + observation: dict[str, Any] + action: dict[str, Any] + instruction: str = "" + + +@dataclass +class SimWorld: + """Complete simulation world state. + + Backend-independent state with engine-specific internals isolated in + ``_model``, ``_data``, and ``_backend_state`` — all typed as ``Any`` + or ``dict`` so that each backend can store its own native handles + without leaking implementation details into this base module. + """ + + robots: dict[str, SimRobot] = field(default_factory=dict) + objects: dict[str, SimObject] = field(default_factory=dict) + cameras: dict[str, SimCamera] = field(default_factory=dict) + timestep: float = 0.002 # 500Hz physics + gravity: list[float] = field(default_factory=lambda: [0.0, 0.0, -9.81]) + ground_plane: bool = True + status: SimStatus = SimStatus.IDLE + sim_time: float = 0.0 + step_count: int = 0 + # Engine-specific internals (set after world is built by the backend). + # Each backend stores its own native handles here. + _model: Any = None # Engine-specific model handle (e.g. MjModel, Scene) + _data: Any = None # Engine-specific data handle (e.g. MjData, World) + # Backend-specific state bag — backends store format-specific data here + # instead of polluting this base class with implementation details. + # E.g. MuJoCo stores {"xml": str, "robot_base_xml": str, "tmpdir": ...} + _backend_state: dict[str, Any] = field(default_factory=dict) + # Trajectory recording + _recording: bool = False + _trajectory: list[TrajectoryStep] = field(default_factory=list) + # LeRobotDataset recorder + _dataset_recorder: Any = None + # Physics state checkpoints (used by save_state/restore_state) + _checkpoints: dict[str, Any] = field(default_factory=dict) diff --git a/strands_robots/simulation/newton/__init__.py b/strands_robots/simulation/newton/__init__.py new file mode 100644 index 0000000..6d5ec44 --- /dev/null +++ b/strands_robots/simulation/newton/__init__.py @@ -0,0 +1,52 @@ +"""Newton/Warp GPU-accelerated simulation backend. + +Provides ``NewtonSimulation(SimEngine)`` for GPU-native physics with 4096+ +parallel environments, differentiable simulation, and 7 solver backends. + +Heavy dependencies (``warp-lang``, ``newton-sim``) are lazy-imported — this +module is safe to import without triggering GPU initialization. + +Usage:: + + from strands_robots.simulation import create_simulation + + sim = create_simulation("newton", num_envs=4096, solver="mujoco") + sim.create_world() + sim.add_robot("so100") + + # Or direct import + from strands_robots.simulation.newton import NewtonSimulation, NewtonConfig +""" + +from __future__ import annotations + +import importlib as _importlib +from typing import Any + +# Light re-exports (no heavy deps) +from strands_robots.simulation.newton.config import NewtonConfig +from strands_robots.simulation.newton.solvers import SOLVER_MAP + +# Lazy-loaded heavy exports +_LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "NewtonSimulation": ( + "strands_robots.simulation.newton.simulation", + "NewtonSimulation", + ), +} + +__all__ = [ + "NewtonConfig", + "NewtonSimulation", + "SOLVER_MAP", +] + + +def __getattr__(name: str) -> Any: + if name in _LAZY_IMPORTS: + module_path, attr_name = _LAZY_IMPORTS[name] + module = _importlib.import_module(module_path) + value = getattr(module, attr_name) + globals()[name] = value + return value + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/strands_robots/simulation/newton/config.py b/strands_robots/simulation/newton/config.py new file mode 100644 index 0000000..10efdec --- /dev/null +++ b/strands_robots/simulation/newton/config.py @@ -0,0 +1,79 @@ +"""Configuration for the Newton simulation backend. + +Validates all user-supplied configuration at construction time so that +errors surface during setup rather than during inference. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from strands_robots.simulation.newton.solvers import ( + BROAD_PHASE_OPTIONS, + RENDER_BACKENDS, + SOLVER_MAP, +) + + +@dataclass +class NewtonConfig: + """Configuration for the Newton simulation backend. + + Parameters + ---------- + num_envs : int + Number of parallel environments. Set to 4096+ for GPU training. + device : str + Warp device string (``"cuda:0"``, ``"cpu"``). + solver : str + Physics solver. See :data:`SOLVER_MAP` for options. + physics_dt : float + Physics timestep in seconds. + substeps : int + Physics substeps per ``step()`` call. + render_backend : str + Rendering backend (``"opengl"``, ``"rerun"``, ``"viser"``, ``"null"``). + enable_cuda_graph : bool + Capture CUDA graph on first ``step()`` for minimal Python overhead. + enable_differentiable : bool + Enable gradient tracking for differentiable simulation. + broad_phase : str + Broad-phase collision detection algorithm. + soft_contact_margin : float + Soft-contact margin distance (metres). + soft_contact_ke : float + Contact elastic stiffness. + soft_contact_kd : float + Contact damping coefficient. + soft_contact_mu : float + Friction coefficient. + soft_contact_restitution : float + Coefficient of restitution (bounciness). + """ + + num_envs: int = 1 + device: str = "cuda:0" + solver: str = "mujoco" + physics_dt: float = 0.005 + substeps: int = 1 + render_backend: str = "null" + enable_cuda_graph: bool = False + enable_differentiable: bool = False + broad_phase: str = "sap" + soft_contact_margin: float = 0.5 + soft_contact_ke: float = 10000.0 + soft_contact_kd: float = 10.0 + soft_contact_mu: float = 0.5 + soft_contact_restitution: float = 0.0 + + def __post_init__(self) -> None: + if self.solver not in SOLVER_MAP: + raise ValueError(f"Unknown solver {self.solver!r}. Available: {sorted(SOLVER_MAP)}") + if self.render_backend not in RENDER_BACKENDS: + raise ValueError(f"Unknown render_backend {self.render_backend!r}. Available: {sorted(RENDER_BACKENDS)}") + if self.broad_phase not in BROAD_PHASE_OPTIONS: + raise ValueError(f"Unknown broad_phase {self.broad_phase!r}. Available: {sorted(BROAD_PHASE_OPTIONS)}") + if self.physics_dt <= 0: + raise ValueError(f"physics_dt must be positive, got {self.physics_dt}") + if self.num_envs < 1: + raise ValueError(f"num_envs must be >= 1, got {self.num_envs}") diff --git a/strands_robots/simulation/newton/simulation.py b/strands_robots/simulation/newton/simulation.py new file mode 100644 index 0000000..2be2efe --- /dev/null +++ b/strands_robots/simulation/newton/simulation.py @@ -0,0 +1,225 @@ +"""Newton GPU-accelerated simulation backend. + +Implements :class:`~strands_robots.simulation.base.SimEngine` using +NVIDIA Warp + Newton for GPU-native physics with 4096+ parallel +environments, differentiable simulation, and 7 solver backends. + +Heavy dependencies (``warp-lang``, ``newton-sim``) are imported lazily +on first use — constructing ``NewtonSimulation`` does **not** trigger +GPU initialisation. + +See Also +-------- +strands_robots.simulation.newton.config.NewtonConfig : + Backend configuration dataclass. +strands_robots.simulation.newton.solvers.SOLVER_MAP : + Available physics solvers. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.newton.config import NewtonConfig + +logger = logging.getLogger(__name__) + + +class NewtonSimulation(SimEngine): + """GPU-native simulation backend built on NVIDIA Warp + Newton. + + This is a **stub** implementation. All abstract methods raise + ``NotImplementedError`` until subsequent PRs land the real logic. + The stub exists so that: + + 1. ``create_simulation("newton")`` resolves and returns an instance. + 2. The factory registry is exercised in CI without GPU dependencies. + 3. Downstream PRs can build on a stable class hierarchy. + + Parameters + ---------- + config : NewtonConfig | None + Backend configuration. If ``None``, defaults are used. + **kwargs : Any + Forwarded to config construction if ``config`` is None. + Accepted keys: ``num_envs``, ``solver``, ``device``, etc. + """ + + def __init__( + self, + config: NewtonConfig | None = None, + **kwargs: Any, + ) -> None: + if config is not None: + self._config = config + elif kwargs: + # Allow create_simulation("newton", num_envs=4096) + self._config = NewtonConfig(**kwargs) + else: + self._config = NewtonConfig() + + logger.info( + "NewtonSimulation created (solver=%s, device=%s, num_envs=%d)", + self._config.solver, + self._config.device, + self._config.num_envs, + ) + + # ------------------------------------------------------------------ + # World lifecycle (stubs) + # ------------------------------------------------------------------ + + def create_world( + self, + timestep: float | None = None, + gravity: list[float] | None = None, + ground_plane: bool = True, + ) -> dict[str, Any]: + """Create a new simulation world. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton create_world not yet implemented") + + def destroy(self) -> dict[str, Any]: + """Destroy the simulation world and release resources. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton destroy not yet implemented") + + def reset(self) -> dict[str, Any]: + """Reset simulation to initial state. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton reset not yet implemented") + + def step(self, n_steps: int = 1) -> dict[str, Any]: + """Advance simulation by *n_steps* physics steps. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton step not yet implemented") + + def get_state(self) -> dict[str, Any]: + """Get full simulation state summary. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton get_state not yet implemented") + + # ------------------------------------------------------------------ + # Robot management (stubs) + # ------------------------------------------------------------------ + + def add_robot( + self, + name: str, + urdf_path: str | None = None, + data_config: str | None = None, + position: list[float] | None = None, + orientation: list[float] | None = None, + ) -> dict[str, Any]: + """Add a robot to the simulation. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton add_robot not yet implemented") + + def remove_robot(self, name: str) -> dict[str, Any]: + """Remove a robot from the simulation. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton remove_robot not yet implemented") + + # ------------------------------------------------------------------ + # Object management (stubs) + # ------------------------------------------------------------------ + + def add_object( + self, + name: str, + shape: str = "box", + position: list[float] | None = None, + orientation: list[float] | None = None, + size: list[float] | None = None, + color: list[float] | None = None, + mass: float = 0.1, + is_static: bool = False, + mesh_path: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Add an object to the scene. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton add_object not yet implemented") + + def remove_object(self, name: str) -> dict[str, Any]: + """Remove an object from the scene. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton remove_object not yet implemented") + + # ------------------------------------------------------------------ + # Observation / Action (stubs) + # ------------------------------------------------------------------ + + def get_observation( + self, + robot_name: str | None = None, + camera_name: str | None = None, + ) -> dict[str, Any]: + """Get observation from simulation. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton get_observation not yet implemented") + + def send_action( + self, + action: dict[str, Any], + robot_name: str | None = None, + n_substeps: int = 1, + ) -> None: + """Apply action to simulation. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton send_action not yet implemented") + + # ------------------------------------------------------------------ + # Rendering (stub) + # ------------------------------------------------------------------ + + def render( + self, + camera_name: str = "default", + width: int | None = None, + height: int | None = None, + ) -> dict[str, Any]: + """Render a camera view. + + .. note:: Stub — will be implemented in a follow-up PR. + """ + raise NotImplementedError("Newton render not yet implemented") + + # ------------------------------------------------------------------ + # Optional overrides (stubs for Newton-specific features) + # ------------------------------------------------------------------ + + def cleanup(self) -> None: + """Release all resources.""" + logger.debug("NewtonSimulation cleanup (stub)") + + def __repr__(self) -> str: + return ( + f"NewtonSimulation(solver={self._config.solver!r}, " + f"device={self._config.device!r}, " + f"num_envs={self._config.num_envs})" + ) diff --git a/strands_robots/simulation/newton/solvers.py b/strands_robots/simulation/newton/solvers.py new file mode 100644 index 0000000..d7f5934 --- /dev/null +++ b/strands_robots/simulation/newton/solvers.py @@ -0,0 +1,39 @@ +"""Newton solver map and backend constants. + +No heavy imports — this module loads instantly. +""" + +from __future__ import annotations + +# Maps user-facing solver name → Newton solver class name. +# Validated during GTC on Jetson AGX Thor (9/14 subtests passed): +# ✅ mujoco, semi_implicit, xpbd +# ❌ featherstone (Warp 1.11 ABI) +# ⚠️ vbd, style3d, implicit_mpm (soft-body/cloth/granular only) +SOLVER_MAP: dict[str, str] = { + "mujoco": "SolverMuJoCo", + "featherstone": "SolverFeatherstone", + "semi_implicit": "SolverSemiImplicit", + "xpbd": "SolverXPBD", + "vbd": "SolverVBD", + "style3d": "SolverStyle3D", + "implicit_mpm": "SolverImplicitMPM", +} + +RENDER_BACKENDS: frozenset[str] = frozenset( + { + "opengl", + "rerun", + "viser", + "null", + "none", + } +) + +BROAD_PHASE_OPTIONS: frozenset[str] = frozenset( + { + "sap", + "bvh", + "none", + } +) diff --git a/strands_robots/tools/__init__.py b/strands_robots/tools/__init__.py index c18ccbb..7ae62c0 100644 --- a/strands_robots/tools/__init__.py +++ b/strands_robots/tools/__init__.py @@ -12,6 +12,7 @@ import importlib as _importlib _LAZY_IMPORTS: dict[str, tuple[str, str]] = { + "download_assets": (".download_assets", "download_assets"), "gr00t_inference": (".gr00t_inference", "gr00t_inference"), "lerobot_calibrate": (".lerobot_calibrate", "lerobot_calibrate"), "lerobot_camera": (".lerobot_camera", "lerobot_camera"), diff --git a/strands_robots/tools/download_assets.py b/strands_robots/tools/download_assets.py new file mode 100644 index 0000000..2f59adf --- /dev/null +++ b/strands_robots/tools/download_assets.py @@ -0,0 +1,84 @@ +"""Download robot model assets — Strands Agent ``@tool`` wrapper. + +Thin wrapper around :mod:`strands_robots.assets.download` that exposes +``download_robots()`` as an agent tool. All download logic lives in the +``assets.download`` module; this file only handles input parsing and +output formatting for the Strands Agent SDK. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from strands.tools.decorator import tool + +from strands_robots.assets.download import download_robots, get_user_assets_dir +from strands_robots.assets.manager import list_available_robots +from strands_robots.registry import format_robot_table + +logger = logging.getLogger(__name__) + + +@tool +def download_assets( + action: str = "download", + robots: str | None = None, + category: str | None = None, + force: bool = False, +) -> dict[str, Any]: + """Download and manage robot model assets (MJCF XML + meshes). + + Assets are sourced from ``robot_descriptions`` (recommended by MuJoCo + Menagerie, requires ``pip install strands-robots[sim-mujoco]``). When + ``robot_descriptions`` is unavailable, falls back to a shallow + ``git clone`` of the Menagerie repo. Robots with a custom GitHub + source in the registry are cloned from their respective repos. + + Downloaded assets are cached in ``~/.strands_robots/assets/`` + (override with ``STRANDS_ASSETS_DIR``). + + Args: + action: ``download`` | ``list`` | ``status`` + robots: Comma-separated names (e.g. ``so100,panda``). Omit for all. + category: Filter: arm, bimanual, hand, humanoid, mobile, mobile_manip + force: Re-download even if present + """ + try: + if action == "list": + return { + "status": "success", + "content": [{"text": f"🤖 Available Robots:\n\n{format_robot_table()}"}], + } + + if action == "status": + robots_info = list_available_robots() + available = sum(1 for r in robots_info if r["available"]) + lines = [f"📊 {available} available, {len(robots_info) - available} missing"] + lines.extend( + f" {'✅' if r['available'] else '❌'} {r['name']:<20s} {r['category']:<12s} {r['description']}" + for r in robots_info + ) + lines.append(f"\n📁 Cache: {get_user_assets_dir()}") + return {"status": "success", "content": [{"text": "\n".join(lines)}]} + + if action == "download": + robot_names = [r.strip() for r in robots.split(",") if r.strip()] if robots else None + result = download_robots(names=robot_names, category=category, force=force) + parts = [ + f"📦 Downloaded: {result['downloaded']}, Skipped: {result['skipped']}, Failed: {result['failed']}", + f"Method: {result.get('method', '?')}", + ] + if result.get("failed_details"): + parts.extend(f" ❌ {n}: {r}" for n, r in result["failed_details"].items()) + parts.append(f"📁 Assets: {result.get('assets_dir', '?')}") + return {"status": "success", "content": [{"text": "\n".join(parts)}]} + + return { + "status": "error", + "content": [{"text": f"Unknown action: {action}. Valid: download, list, status"}], + } + + except Exception as exc: + logger.error("download_assets error: %s", exc) + return {"status": "error", "content": [{"text": f"❌ Error: {exc}"}]} diff --git a/strands_robots/utils.py b/strands_robots/utils.py index f2a930c..4a9cb71 100644 --- a/strands_robots/utils.py +++ b/strands_robots/utils.py @@ -2,6 +2,8 @@ import importlib import logging +import os +from pathlib import Path logger = logging.getLogger(__name__) @@ -49,3 +51,71 @@ def require_optional( parts.append(f" pip install 'strands-robots[{extra}]'") parts.append(f" pip install {install_hint}") raise ImportError("\n".join(parts)) from None + + +# ───────────────────────────────────────────────────────────────────── +# Path resolution — single source of truth for all strands-robots paths +# ───────────────────────────────────────────────────────────────────── + +#: Default base directory for all user data. +DEFAULT_BASE_DIR = Path.home() / ".strands_robots" + + +def get_base_dir() -> Path: + """Get the base directory for strands-robots user data. + + If ``STRANDS_ASSETS_DIR`` is set, returns its parent + (the assets dir is a subdirectory of the base). + Otherwise returns ``~/.strands_robots/``. + + Returns: + Path to the base directory (created if needed). + """ + custom = os.getenv("STRANDS_ASSETS_DIR") + if custom: + d = Path(custom).parent + else: + d = DEFAULT_BASE_DIR + d.mkdir(parents=True, exist_ok=True) + return d + + +def get_assets_dir() -> Path: + """Get the assets directory (robot model files, meshes, URDFs). + + Resolution: + 1. ``STRANDS_ASSETS_DIR`` env var — used as-is + 2. ``~/.strands_robots/assets/`` — default + + Returns: + Path to the assets directory (created if needed). + """ + custom = os.getenv("STRANDS_ASSETS_DIR") + if custom: + d = Path(custom) + else: + d = DEFAULT_BASE_DIR / "assets" + d.mkdir(parents=True, exist_ok=True) + return d + + +def resolve_asset_path(relative_or_absolute: str | Path | None, default_name: str = "") -> Path: + """Resolve an asset path against the assets directory. + + Args: + relative_or_absolute: Path to resolve. + - ``None`` → ``//`` + - Absolute (or ``~/...``) → expanded as-is + - Relative → ``//`` + default_name: Fallback subdirectory name when path is None. + + Returns: + Resolved absolute Path. + """ + assets = get_assets_dir() + if relative_or_absolute is None: + return assets / default_name + expanded = Path(relative_or_absolute).expanduser() + if expanded.is_absolute(): + return expanded + return assets / expanded diff --git a/tests/simulation/__init__.py b/tests/simulation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/simulation/newton/__init__.py b/tests/simulation/newton/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/simulation/newton/test_newton_stub.py b/tests/simulation/newton/test_newton_stub.py new file mode 100644 index 0000000..d2b6a91 --- /dev/null +++ b/tests/simulation/newton/test_newton_stub.py @@ -0,0 +1,216 @@ +"""Tests for Newton backend stub — config, factory registration, lazy imports. + +These tests require NO GPU and NO warp/newton installation. +They verify the class hierarchy, factory resolution, and configuration +validation that runs before any physics engine is touched. +""" + +from __future__ import annotations + +import pytest + +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.factory import ( + _BUILTIN_ALIASES, + _BUILTIN_BACKENDS, + _resolve_name, + create_simulation, + list_backends, +) +from strands_robots.simulation.newton.config import NewtonConfig +from strands_robots.simulation.newton.simulation import NewtonSimulation +from strands_robots.simulation.newton.solvers import ( + BROAD_PHASE_OPTIONS, + RENDER_BACKENDS, + SOLVER_MAP, +) + +# ── Factory registration ────────────────────────────────────────────── + + +class TestFactoryRegistration: + """Verify Newton is registered in the simulation factory.""" + + def test_newton_in_builtin_backends(self) -> None: + assert "newton" in _BUILTIN_BACKENDS + module_path, class_name = _BUILTIN_BACKENDS["newton"] + assert module_path == "strands_robots.simulation.newton.simulation" + assert class_name == "NewtonSimulation" + + def test_warp_alias_resolves_to_newton(self) -> None: + assert "warp" in _BUILTIN_ALIASES + assert _BUILTIN_ALIASES["warp"] == "newton" + assert _resolve_name("warp") == "newton" + + def test_list_backends_includes_newton(self) -> None: + backends = list_backends() + assert "newton" in backends + assert "warp" in backends + + def test_create_simulation_newton(self) -> None: + sim = create_simulation("newton") + assert isinstance(sim, NewtonSimulation) + assert isinstance(sim, SimEngine) + + def test_create_simulation_warp_alias(self) -> None: + sim = create_simulation("warp") + assert isinstance(sim, NewtonSimulation) + + def test_create_simulation_with_kwargs(self) -> None: + sim = create_simulation("newton", num_envs=4096, solver="xpbd") + assert isinstance(sim, NewtonSimulation) + assert sim._config.num_envs == 4096 + assert sim._config.solver == "xpbd" + + +# ── NewtonSimulation class ──────────────────────────────────────────── + + +class TestNewtonSimulation: + """Verify the stub class hierarchy and behaviour.""" + + def test_is_simengine_subclass(self) -> None: + assert issubclass(NewtonSimulation, SimEngine) + + def test_default_construction(self) -> None: + sim = NewtonSimulation() + assert sim._config.solver == "mujoco" + assert sim._config.device == "cuda:0" + assert sim._config.num_envs == 1 + + def test_construction_with_config(self) -> None: + cfg = NewtonConfig(solver="xpbd", num_envs=64, device="cpu") + sim = NewtonSimulation(config=cfg) + assert sim._config.solver == "xpbd" + assert sim._config.num_envs == 64 + + def test_construction_with_kwargs(self) -> None: + sim = NewtonSimulation(num_envs=256, solver="semi_implicit") + assert sim._config.num_envs == 256 + assert sim._config.solver == "semi_implicit" + + def test_repr(self) -> None: + sim = NewtonSimulation() + r = repr(sim) + assert "NewtonSimulation" in r + assert "mujoco" in r + + def test_context_manager(self) -> None: + with NewtonSimulation() as sim: + assert isinstance(sim, NewtonSimulation) + + def test_cleanup_does_not_raise(self) -> None: + sim = NewtonSimulation() + sim.cleanup() # Should be a no-op, not raise + + @pytest.mark.parametrize( + "method,args", + [ + ("create_world", ()), + ("destroy", ()), + ("reset", ()), + ("step", ()), + ("get_state", ()), + ("add_robot", ("test_robot",)), + ("remove_robot", ("test_robot",)), + ("add_object", ("test_obj",)), + ("remove_object", ("test_obj",)), + ("get_observation", ()), + ("send_action", ({"joint_0": 0.5},)), + ("render", ()), + ], + ) + def test_abstract_methods_raise_not_implemented(self, method: str, args: tuple) -> None: + sim = NewtonSimulation() + with pytest.raises(NotImplementedError, match="Newton"): + getattr(sim, method)(*args) + + +# ── NewtonConfig validation ─────────────────────────────────────────── + + +class TestNewtonConfig: + """Verify config validates inputs at construction time.""" + + def test_default_config(self) -> None: + cfg = NewtonConfig() + assert cfg.solver == "mujoco" + assert cfg.device == "cuda:0" + assert cfg.num_envs == 1 + assert cfg.physics_dt == 0.005 + assert cfg.substeps == 1 + assert cfg.render_backend == "null" + + def test_all_solvers_accepted(self) -> None: + for solver in SOLVER_MAP: + cfg = NewtonConfig(solver=solver) + assert cfg.solver == solver + + def test_invalid_solver_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown solver"): + NewtonConfig(solver="nonexistent") + + def test_invalid_render_backend_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown render_backend"): + NewtonConfig(render_backend="vulkan") + + def test_invalid_broad_phase_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown broad_phase"): + NewtonConfig(broad_phase="octree") + + def test_negative_dt_raises(self) -> None: + with pytest.raises(ValueError, match="physics_dt must be positive"): + NewtonConfig(physics_dt=-0.001) + + def test_zero_dt_raises(self) -> None: + with pytest.raises(ValueError, match="physics_dt must be positive"): + NewtonConfig(physics_dt=0.0) + + def test_zero_envs_raises(self) -> None: + with pytest.raises(ValueError, match="num_envs must be >= 1"): + NewtonConfig(num_envs=0) + + +# ── Solver map constants ────────────────────────────────────────────── + + +class TestSolverConstants: + """Verify solver map and constant sets are well-formed.""" + + def test_solver_map_has_seven_entries(self) -> None: + assert len(SOLVER_MAP) == 7 + + def test_expected_solvers_present(self) -> None: + expected = {"mujoco", "featherstone", "semi_implicit", "xpbd", "vbd", "style3d", "implicit_mpm"} + assert set(SOLVER_MAP.keys()) == expected + + def test_render_backends_includes_null(self) -> None: + assert "null" in RENDER_BACKENDS + assert "none" in RENDER_BACKENDS + assert "opengl" in RENDER_BACKENDS + + def test_broad_phase_includes_sap(self) -> None: + assert "sap" in BROAD_PHASE_OPTIONS + + +# ── Lazy import guard ───────────────────────────────────────────────── + + +class TestLazyImports: + """Verify importing the newton package does NOT trigger warp/newton loads.""" + + def test_import_init_does_not_load_warp(self) -> None: + import sys + + # If warp were eagerly imported, it would be in sys.modules + # after importing the newton package. Since we don't have + # warp installed in CI, an eager import would raise ImportError. + import strands_robots.simulation.newton # noqa: F401 + + # Should succeed without warp being present + assert "strands_robots.simulation.newton" in sys.modules + + def test_import_config_is_lightweight(self) -> None: + import strands_robots.simulation.newton.config # noqa: F401 + import strands_robots.simulation.newton.solvers # noqa: F401 + # These must succeed with zero heavy deps diff --git a/tests/test_simulation_foundation.py b/tests/test_simulation_foundation.py new file mode 100644 index 0000000..c06fcf6 --- /dev/null +++ b/tests/test_simulation_foundation.py @@ -0,0 +1,359 @@ +"""Tests for simulation foundation — models, ABC, factory, model_registry. + +These tests verify the lightweight simulation abstractions without +requiring MuJoCo or any heavy dependencies. +""" + +import pytest + +from strands_robots.simulation.base import SimEngine +from strands_robots.simulation.factory import ( + create_simulation, + list_backends, + register_backend, +) +from strands_robots.simulation.models import ( + SimObject, + SimRobot, + SimStatus, + SimWorld, + TrajectoryStep, +) + +# ── ABC Tests ──────────────────────────────────────────────────── + + +class TestSimEngine: + """Test the abstract base class contract.""" + + def test_cannot_instantiate_abc(self): + with pytest.raises(TypeError): + SimEngine() + + def test_has_required_abstract_methods(self): + abstract_methods = SimEngine.__abstractmethods__ + expected = { + "create_world", + "destroy", + "reset", + "step", + "get_state", + "add_robot", + "remove_robot", + "add_object", + "remove_object", + "get_observation", + "send_action", + "render", + } + assert expected == abstract_methods + + def test_optional_methods_raise_not_implemented(self): + """Optional methods on a concrete subclass raise NotImplementedError.""" + + class Dummy(SimEngine): + def create_world(self, **kw): + return {} + + def destroy(self): + return {} + + def reset(self): + return {} + + def step(self, n_steps=1): + return {} + + def get_state(self): + return {} + + def add_robot(self, name, **kw): + return {} + + def remove_robot(self, name): + return {} + + def add_object(self, name, **kw): + return {} + + def remove_object(self, name): + return {} + + def get_observation(self, **kw): + return {} + + def send_action(self, action, **kw): + return None + + def render(self, **kw): + return {} + + d = Dummy() + with pytest.raises(NotImplementedError): + d.load_scene("x") + with pytest.raises(NotImplementedError): + d.run_policy("x") + with pytest.raises(NotImplementedError): + d.randomize() + with pytest.raises(NotImplementedError): + d.get_contacts() + + def test_context_manager_calls_cleanup(self): + """ABC supports context manager protocol and calls cleanup on exit.""" + + class Dummy(SimEngine): + cleaned = False + + def create_world(self, **kw): + return {} + + def destroy(self): + return {} + + def reset(self): + return {} + + def step(self, n_steps=1): + return {} + + def get_state(self): + return {} + + def add_robot(self, name, **kw): + return {} + + def remove_robot(self, name): + return {} + + def add_object(self, name, **kw): + return {} + + def remove_object(self, name): + return {} + + def get_observation(self, **kw): + return {} + + def send_action(self, action, **kw): + return None + + def render(self, **kw): + return {} + + def cleanup(self): + Dummy.cleaned = True + + with Dummy() as _d: + pass + assert Dummy.cleaned is True + + +# ── Factory Tests ──────────────────────────────────────────────── + + +class TestSimulationFactory: + """Test backend registration and creation — full round-trip.""" + + def test_list_backends_includes_mujoco(self): + backends = list_backends() + assert "mujoco" in backends + + def test_register_create_and_use_backend(self): + """Register a custom backend, create it via factory, verify instance.""" + + class FakeBackend(SimEngine): + def create_world(self, **kw): + return {} + + def destroy(self): + return {} + + def reset(self): + return {} + + def step(self, n_steps=1): + return {} + + def get_state(self): + return {} + + def add_robot(self, name, **kw): + return {} + + def remove_robot(self, name): + return {} + + def add_object(self, name, **kw): + return {} + + def remove_object(self, name): + return {} + + def get_observation(self, **kw): + return {} + + def send_action(self, action, **kw): + return None + + def render(self, **kw): + return {} + + register_backend("fake_test", lambda: FakeBackend, force=True) + assert "fake_test" in list_backends() + sim = create_simulation("fake_test") + assert isinstance(sim, FakeBackend) + + def test_register_rejects_duplicate(self): + """Registering an existing name without force raises ValueError.""" + + class Dummy(SimEngine): + def create_world(self, **kw): + return {} + + def destroy(self): + return {} + + def reset(self): + return {} + + def step(self, n_steps=1): + return {} + + def get_state(self): + return {} + + def add_robot(self, name, **kw): + return {} + + def remove_robot(self, name): + return {} + + def add_object(self, name, **kw): + return {} + + def remove_object(self, name): + return {} + + def get_observation(self, **kw): + return {} + + def send_action(self, action, **kw): + return None + + def render(self, **kw): + return {} + + register_backend("dup_test", lambda: Dummy, force=True) + with pytest.raises(ValueError, match="already registered"): + register_backend("dup_test", lambda: Dummy) + + def test_register_rejects_builtin_alias(self): + """Cannot hijack built-in aliases like 'mj'.""" + + class Dummy(SimEngine): + def create_world(self, **kw): + return {} + + def destroy(self): + return {} + + def reset(self): + return {} + + def step(self, n_steps=1): + return {} + + def get_state(self): + return {} + + def add_robot(self, name, **kw): + return {} + + def remove_robot(self, name): + return {} + + def add_object(self, name, **kw): + return {} + + def remove_object(self, name): + return {} + + def get_observation(self, **kw): + return {} + + def send_action(self, action, **kw): + return None + + def render(self, **kw): + return {} + + with pytest.raises(ValueError, match="conflicts with built-in"): + register_backend("custom_phys", lambda: Dummy, aliases=["mj"]) + + +# ── Model Registry Tests ───────────────────────────────────────── + + +class TestModelRegistry: + """Test URDF/MJCF model resolution.""" + + def test_list_available_models_returns_robot_table(self): + from strands_robots.simulation.model_registry import list_available_models + + models = list_available_models() + assert isinstance(models, str) + assert "so100" in models + assert len(models) > 100 + + def test_register_and_resolve_urdf(self, tmp_path): + """Register a URDF, resolve it back — full round-trip.""" + from strands_robots.simulation.model_registry import register_urdf, resolve_urdf + + urdf_file = tmp_path / "robot.urdf" + urdf_file.write_text("") + register_urdf("test_robot_xyz", str(urdf_file)) + result = resolve_urdf("test_robot_xyz") + assert result == str(urdf_file) + + def test_list_registered_urdfs(self): + from strands_robots.simulation.model_registry import list_registered_urdfs, register_urdf + + register_urdf("list_test_bot", "/fake/list.urdf") + urdfs = list_registered_urdfs() + assert isinstance(urdfs, dict) + assert "list_test_bot" in urdfs + + +# ── Dataclass Behavioral Tests ─────────────────────────────────── + + +class TestSimModelsUsage: + """Test that simulation models behave correctly in real usage patterns.""" + + def test_sim_world_tracks_robots(self): + """SimWorld can add robots and objects — simulates real world setup.""" + world = SimWorld() + robot = SimRobot(name="so100", urdf_path="/p") + world.robots["so100"] = robot + assert "so100" in world.robots + assert world.status == SimStatus.IDLE + + def test_sim_object_preserves_originals_for_randomization(self): + """SimObject stores original position/color for domain randomization reset.""" + obj = SimObject(name="ball", shape="sphere", position=[1, 2, 3], color=[1, 0, 0, 1]) + assert obj._original_position == [1, 2, 3] + assert obj._original_color == [1, 0, 0, 1] + + def test_trajectory_step_records_episode_data(self): + """TrajectoryStep captures full observation-action pair for dataset recording.""" + step = TrajectoryStep( + timestamp=1.0, + sim_time=0.5, + robot_name="arm", + observation={"state": [1, 2, 3]}, + action={"joint_0": 0.5}, + instruction="pick up cube", + ) + assert step.robot_name == "arm" + assert step.instruction == "pick up cube" + assert step.observation["state"] == [1, 2, 3] diff --git a/tests/test_user_registry.py b/tests/test_user_registry.py new file mode 100644 index 0000000..ee05e2f --- /dev/null +++ b/tests/test_user_registry.py @@ -0,0 +1,364 @@ +"""Tests for user-local robot registry and shared path utilities. + +Covers: + - strands_robots.registry.user_registry (register, unregister, list, persistence) + - strands_robots.registry.loader._merge_user_robots (user overlay merge) + - strands_robots.utils (get_base_dir, get_assets_dir, resolve_asset_path) +""" + +import json +import logging +import os +from pathlib import Path +from unittest import mock + +import pytest + +from strands_robots.registry import get_robot, list_robots, resolve_name +from strands_robots.registry.user_registry import ( + _get_user_registry_path, + _invalidate_cache, + _load_user_registry, + get_user_robots, + list_user_robots, + register_robot, + unregister_robot, +) +from strands_robots.utils import get_assets_dir, get_base_dir, resolve_asset_path + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_MINIMAL_MJCF = '' + + +@pytest.fixture(autouse=True) +def _isolate_registry(tmp_path, monkeypatch): + """Point STRANDS_ASSETS_DIR to a temp dir and clear caches for every test.""" + assets_dir = tmp_path / "assets" + assets_dir.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(assets_dir)) + _invalidate_cache() + yield + _invalidate_cache() + + +def _make_robot(parent: Path, name: str = "test_bot", xml_name: str = "bot.xml") -> Path: + """Create a minimal MJCF robot directory and return its path.""" + d = parent / name + d.mkdir(parents=True, exist_ok=True) + (d / xml_name).write_text(_MINIMAL_MJCF) + return d + + +# =========================================================================== +# Registration +# =========================================================================== + + +class TestRegisterRobot: + """register_robot() stores metadata and makes the robot discoverable.""" + + def test_stores_description_category_and_joints(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + entry = register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + description="A test bot", + category="arm", + joints=3, + ) + assert entry["description"] == "A test bot" + assert entry["category"] == "arm" + assert entry["joints"] == 3 + assert entry["asset"]["model_xml"] == "bot.xml" + + def test_robot_visible_via_get_robot(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert get_robot("test_bot") is not None + + def test_robot_visible_in_list_robots(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert "test_bot" in [r["name"] for r in list_robots()] + + def test_aliases_resolve_to_canonical_name(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + aliases=["my_bot", "tb"], + ) + assert resolve_name("my_bot") == "test_bot" + assert resolve_name("tb") == "test_bot" + + def test_stores_robot_descriptions_module(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + entry = register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + robot_descriptions_module="my_pkg.test_bot", + ) + assert entry["asset"]["robot_descriptions_module"] == "my_pkg.test_bot" + + def test_stores_hardware_config(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + hw = {"lerobot_type": "so100_follower", "cameras": {"top": 0}} + entry = register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + hardware=hw, + ) + assert entry["hardware"] == hw + + +class TestRegisterRobotNameNormalization: + """Names are lower-cased, stripped, and hyphens become underscores.""" + + def test_normalizes_whitespace_hyphens_and_case(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name=" My-Bot ", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert get_robot("my_bot") is not None + + +class TestRegisterRobotDuplicates: + """Duplicate handling: raise by default, allow with overwrite=True.""" + + def test_duplicate_raises_by_default(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + with pytest.raises(ValueError, match="already in user registry"): + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + + def test_overwrite_replaces_existing(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir), description="v1") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir), description="v2", overwrite=True) + assert get_robot("test_bot")["description"] == "v2" + + def test_overriding_package_robot_logs_info(self, tmp_path, caplog): + """Registering a name that exists in the package registry emits an info log.""" + panda_dir = _make_robot(tmp_path / "assets", name="panda", xml_name="panda.xml") + with caplog.at_level(logging.INFO, logger="strands_robots.registry.user_registry"): + register_robot( + name="panda", + model_xml="panda.xml", + asset_dir=str(panda_dir), + description="Custom panda", + ) + assert any("exists in package registry" in m for m in caplog.messages) + assert get_robot("panda")["description"] == "Custom panda" + unregister_robot("panda") + + +class TestRegisterRobotValidation: + """register_robot rejects invalid inputs.""" + + def test_missing_model_xml_raises_file_not_found(self, tmp_path): + empty_dir = tmp_path / "assets" / "empty" + empty_dir.mkdir(parents=True) + with pytest.raises(FileNotFoundError, match="Model XML not found"): + register_robot(name="empty", model_xml="nope.xml", asset_dir=str(empty_dir)) + + +class TestRegisterRobotAssetDirResolution: + """asset_dir is resolved relative to STRANDS_ASSETS_DIR.""" + + def test_none_defaults_to_assets_subdir(self, tmp_path): + default_dir = _make_robot(tmp_path / "assets", name="auto_bot", xml_name="auto.xml") + entry = register_robot(name="auto_bot", model_xml="auto.xml") + assert entry["_user_asset_path"] == str(default_dir) + + def test_relative_path_resolved_against_assets(self, tmp_path): + rel_dir = tmp_path / "assets" / "sub" / "bot" + rel_dir.mkdir(parents=True) + (rel_dir / "r.xml").write_text(_MINIMAL_MJCF) + entry = register_robot(name="rel_bot", model_xml="r.xml", asset_dir="sub/bot") + assert entry["_user_asset_path"] == str(rel_dir) + + def test_absolute_path_used_as_is(self, tmp_path): + abs_dir = tmp_path / "elsewhere" / "bot" + abs_dir.mkdir(parents=True) + (abs_dir / "b.xml").write_text(_MINIMAL_MJCF) + entry = register_robot(name="abs_bot", model_xml="b.xml", asset_dir=str(abs_dir)) + assert entry["_user_asset_path"] == str(abs_dir) + + +# =========================================================================== +# Unregistration +# =========================================================================== + + +class TestUnregisterRobot: + """unregister_robot() removes from user registry only.""" + + def test_removes_registered_robot(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + assert unregister_robot("test_bot") is True + assert get_user_robots().get("test_bot") is None + + def test_returns_false_for_nonexistent(self): + assert unregister_robot("nonexistent") is False + + def test_does_not_affect_package_robots(self): + assert get_robot("panda") is not None + assert unregister_robot("panda") is False + assert get_robot("panda") is not None + + +# =========================================================================== +# Listing +# =========================================================================== + + +class TestListUserRobots: + """list_user_robots() returns user-registered robots only.""" + + def test_empty_when_nothing_registered(self): + assert list_user_robots() == [] + + def test_returns_registered_robot_metadata(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot( + name="test_bot", + model_xml="bot.xml", + asset_dir=str(robot_dir), + description="Desc", + joints=5, + ) + result = list_user_robots() + assert len(result) == 1 + assert result[0]["name"] == "test_bot" + assert result[0]["description"] == "Desc" + assert result[0]["joints"] == 5 + assert result[0]["model_xml"] == "bot.xml" + + +# =========================================================================== +# Persistence +# =========================================================================== + + +class TestPersistence: + """User registry persists to a JSON file and survives corruption.""" + + def test_writes_json_file(self, tmp_path): + robot_dir = _make_robot(tmp_path / "assets") + register_robot(name="test_bot", model_xml="bot.xml", asset_dir=str(robot_dir)) + path = _get_user_registry_path() + assert path.exists() + data = json.loads(path.read_text()) + assert "test_bot" in data["robots"] + + def test_corrupted_json_returns_empty(self): + path = _get_user_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("NOT JSON!!!") + assert _load_user_registry() == {"robots": {}} + + def test_valid_json_without_robots_key_returns_empty(self): + path = _get_user_registry_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text('{"version": 1}') + assert _load_user_registry() == {"robots": {}} + + +# =========================================================================== +# Loader merge +# =========================================================================== + + +class TestLoaderMerge: + """_merge_user_robots gracefully handles missing user_registry module.""" + + def test_import_error_returns_data_unchanged(self): + from strands_robots.registry.loader import _merge_user_robots + + data = {"robots": {"fake": {"description": "test"}}} + with mock.patch.dict("sys.modules", {"strands_robots.registry.user_registry": None}): + result = _merge_user_robots(data) + assert "fake" in result["robots"] + + +# =========================================================================== +# STRANDS_ASSETS_DIR integration +# =========================================================================== + + +class TestStrandsAssetsDirIntegration: + """Registry file location respects STRANDS_ASSETS_DIR env var.""" + + def test_registry_file_in_parent_of_assets_dir(self, tmp_path): + custom = tmp_path / "custom_assets" + custom.mkdir() + with mock.patch.dict(os.environ, {"STRANDS_ASSETS_DIR": str(custom)}): + assert _get_user_registry_path().parent == custom.parent + + def test_defaults_to_dot_strands_robots(self, monkeypatch): + monkeypatch.delenv("STRANDS_ASSETS_DIR", raising=False) + assert ".strands_robots" in str(_get_user_registry_path()) + + +# =========================================================================== +# Path utilities (strands_robots.utils) +# =========================================================================== + + +class TestGetAssetsDir: + """get_assets_dir() returns STRANDS_ASSETS_DIR or ~/.strands_robots/assets/.""" + + def test_default(self, monkeypatch): + monkeypatch.delenv("STRANDS_ASSETS_DIR", raising=False) + result = get_assets_dir() + assert str(result).endswith("assets") + assert ".strands_robots" in str(result) + + def test_custom(self, tmp_path, monkeypatch): + custom = tmp_path / "my_assets" + custom.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(custom)) + assert get_assets_dir() == custom + + +class TestGetBaseDir: + """get_base_dir() returns parent of STRANDS_ASSETS_DIR or ~/.strands_robots/.""" + + def test_default(self, monkeypatch): + monkeypatch.delenv("STRANDS_ASSETS_DIR", raising=False) + assert str(get_base_dir()).endswith(".strands_robots") + + def test_custom(self, tmp_path, monkeypatch): + custom = tmp_path / "custom_assets" + custom.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(custom)) + assert get_base_dir() == tmp_path + + +class TestResolveAssetPath: + """resolve_asset_path() resolves None, relative, absolute, and ~ paths.""" + + def test_none_returns_assets_dir_plus_default_name(self, tmp_path, monkeypatch): + assets = tmp_path / "a" + assets.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(assets)) + assert resolve_asset_path(None, "robot") == assets / "robot" + + def test_relative_resolved_against_assets_dir(self, tmp_path, monkeypatch): + assets = tmp_path / "a" + assets.mkdir() + monkeypatch.setenv("STRANDS_ASSETS_DIR", str(assets)) + assert resolve_asset_path("sub/dir") == assets / "sub" / "dir" + + def test_absolute_path_unchanged(self): + assert resolve_asset_path("/absolute/path") == Path("/absolute/path") + + def test_tilde_expanded(self): + result = resolve_asset_path("~/robots") + assert str(result).startswith(str(Path.home()))