diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 4921e64..10c23cd 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -7,6 +7,7 @@ """ import os +import re import socket import subprocess import time @@ -14,6 +15,82 @@ from strands import tool +# ───────────────────────────────────────────────────────────────────── +# Input validation helpers +# ───────────────────────────────────────────────────────────────────── + +# Characters that must never appear in values interpolated into commands. +_SHELL_META = re.compile(r"[;&|`$(){}\[\]!<>\\'\"\n\r\x00]") + +# Strict patterns for enumerable parameters. +_DATA_CONFIG_RE = re.compile(r"^[a-z][a-z0-9_]{0,63}$") +_EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$") +_CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$") + +# Allowlists for TensorRT dtype parameters. +_VALID_VIT_DTYPES = {"fp16", "fp8"} +_VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"} +_VALID_DIT_DTYPES = {"fp16", "fp8"} + + +def _validate_path(value: str, label: str) -> None: + """Reject paths containing shell metacharacters, null bytes, or traversal sequences.""" + if "\x00" in value: + raise ValueError(f"{label} must not contain null bytes") + if ".." in value.split("/"): + raise ValueError(f"{label} must not contain '..' path traversal components") + if _SHELL_META.search(value): + raise ValueError(f"{label} contains disallowed characters: {value!r}") + + +def validate_inputs( + *, + data_config: str, + embodiment_tag: str, + port: int, + vit_dtype: str, + llm_dtype: str, + dit_dtype: str, + checkpoint_path: str | None = None, + trt_engine_path: str = "gr00t_engine", + container_name: str | None = None, +) -> None: + """Validate all user-supplied parameters in one place. + + Raises ValueError for any invalid input. This centralises validation so + that the main tool function stays focused on orchestration and each + check is independently testable via this single entry-point. + """ + # Enumerable string parameters + if not _DATA_CONFIG_RE.match(data_config): + raise ValueError( + f"data_config must be lowercase alphanumeric/underscore (got {data_config!r}). " + f"See the tool docstring for the full list of accepted configs." + ) + if not _EMBODIMENT_TAG_RE.match(embodiment_tag): + raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {embodiment_tag!r})") + + # Docker container name + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") + + # Filesystem paths — reject shell metacharacters and traversal + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + _validate_path(trt_engine_path, "trt_engine_path") + + # TensorRT dtype allowlists + if vit_dtype not in _VALID_VIT_DTYPES: + raise ValueError(f"vit_dtype must be one of {_VALID_VIT_DTYPES}, got {vit_dtype!r}") + if llm_dtype not in _VALID_LLM_DTYPES: + raise ValueError(f"llm_dtype must be one of {_VALID_LLM_DTYPES}, got {llm_dtype!r}") + if dit_dtype not in _VALID_DIT_DTYPES: + raise ValueError(f"dit_dtype must be one of {_VALID_DIT_DTYPES}, got {dit_dtype!r}") + + # Port range + if not (1 <= port <= 65535): + raise ValueError(f"port must be between 1 and 65535, got {port}") + @tool def gr00t_inference( @@ -24,7 +101,7 @@ def gr00t_inference( data_config: str = "fourier_gr1_arms_only", embodiment_tag: str = "gr1", denoising_steps: int = 4, - host: str = "0.0.0.0", + host: str = "127.0.0.1", container_name: str | None = None, timeout: int = 60, use_tensorrt: bool = False, @@ -112,7 +189,7 @@ def gr00t_inference( data_config: Embodiment data config name (see Data configs above). embodiment_tag: Embodiment tag for the model (e.g., ``gr1``, ``so100``). denoising_steps: Number of denoising steps for action generation (default: 4). - host: Host address to bind the service to (default: ``0.0.0.0``). + host: Host address to bind the service to (default: ``127.0.0.1``). container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). @@ -180,6 +257,19 @@ def gr00t_inference( if api_token is None: api_token = os.environ.get("GROOT_API_TOKEN") + # ── Validate all inputs in one call ─────────────────────────────── + validate_inputs( + data_config=data_config, + embodiment_tag=embodiment_tag, + port=port, + vit_dtype=vit_dtype, + llm_dtype=llm_dtype, + dit_dtype=dit_dtype, + checkpoint_path=checkpoint_path, + trt_engine_path=trt_engine_path, + container_name=container_name, + ) + if action == "find_containers": return _find_gr00t_containers() elif action == "list": @@ -314,6 +404,27 @@ def _check_service_status(port: int) -> dict[str, Any]: } +def _is_gr00t_process(container_name: str, pid: str) -> bool: + """Verify that a PID inside a container belongs to a GR00T inference process. + + This prevents accidentally killing unrelated processes that happen to + be listening on the same port. + """ + try: + result = subprocess.run( + ["docker", "exec", container_name, "cat", f"/proc/{pid}/cmdline"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + cmdline = result.stdout.replace("\x00", " ") + return "inference_service" in cmdline or "gr00t" in cmdline.lower() + except Exception: + pass + return False + + def _stop_service(port: int) -> dict[str, Any]: """Stop GR00T inference service running on specific port.""" try: @@ -334,13 +445,21 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid): subprocess.run(["docker", "exec", container_name, "kill", "-TERM", pid], check=True) time.sleep(2) result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", + "exec", + container_name, + "pgrep", + "-f", + f"inference_service.py.*--port {port}", + ], capture_output=True, text=True, check=False, @@ -349,7 +468,8 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid): subprocess.run(["docker", "exec", container_name, "kill", "-KILL", pid], check=True) return { @@ -362,22 +482,32 @@ def _stop_service(port: int) -> dict[str, Any]: except subprocess.CalledProcessError: continue - # Fallback: try host system - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + # Fallback: try host system — only kill processes that match inference_service + result = subprocess.run( + ["pgrep", "-f", f"inference_service.py.*--port {port}"], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: + pid = pid.strip() if pid: subprocess.run(["kill", "-TERM", pid], check=True) time.sleep(2) - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + result = subprocess.run( + ["pgrep", "-f", f"inference_service.py.*--port {port}"], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: + pid = pid.strip() if pid: subprocess.run(["kill", "-KILL", pid], check=True) diff --git a/tests/groot/test_gr00t_inference_validation.py b/tests/groot/test_gr00t_inference_validation.py new file mode 100644 index 0000000..f059ce6 --- /dev/null +++ b/tests/groot/test_gr00t_inference_validation.py @@ -0,0 +1,194 @@ +"""Tests for gr00t_inference input validation.""" + +import pytest + +from strands_robots.tools.gr00t_inference import validate_inputs + + +class TestValidateInputs: + """Test the centralised validate_inputs() function.""" + + def test_valid_defaults(self): + """Default parameter values must pass validation.""" + validate_inputs( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_valid_with_all_optional(self): + """All optional parameters provided with valid values.""" + validate_inputs( + data_config="so100_dualcam", + embodiment_tag="so100", + port=8000, + vit_dtype="fp16", + llm_dtype="fp16", + dit_dtype="fp16", + checkpoint_path="/data/checkpoints/model", + trt_engine_path="my_engine", + container_name="isaac-gr00t-1", + ) + + # ── data_config ────────────────────────────────────────────────── + + def test_invalid_data_config_uppercase(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + data_config="UPPER", + embodiment_tag="gr1", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_invalid_data_config_shell_chars(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + data_config="config;rm -rf /", + embodiment_tag="gr1", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + # ── embodiment_tag ─────────────────────────────────────────────── + + def test_invalid_embodiment_tag(self): + with pytest.raises(ValueError, match="embodiment_tag"): + validate_inputs( + data_config="so100", + embodiment_tag="BAD TAG!", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + # ── port ───────────────────────────────────────────────────────── + + def test_port_zero(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=0, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_port_too_high(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=70000, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + # ── dtype allowlists ───────────────────────────────────────────── + + def test_invalid_vit_dtype(self): + with pytest.raises(ValueError, match="vit_dtype"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="int8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_invalid_llm_dtype(self): + with pytest.raises(ValueError, match="llm_dtype"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="bf16", + dit_dtype="fp8", + ) + + def test_invalid_dit_dtype(self): + with pytest.raises(ValueError, match="dit_dtype"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="nvfp4", + ) + + # ── path validation ────────────────────────────────────────────── + + def test_checkpoint_path_traversal(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/data/../../../etc/passwd", + ) + + def test_checkpoint_path_null_byte(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/data/model\x00.bin", + ) + + def test_trt_engine_path_shell_injection(self): + with pytest.raises(ValueError, match="trt_engine_path"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + trt_engine_path="engine; rm -rf /", + ) + + # ── container_name ─────────────────────────────────────────────── + + def test_invalid_container_name(self): + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + container_name="-invalid", + ) + + def test_container_name_none_is_ok(self): + """container_name=None should pass (auto-detect).""" + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + container_name=None, + )