From 39c315331e2a532247be4a573aaee41a5ca479dc Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 18:45:11 -0400 Subject: [PATCH 1/3] improve: stricter input validation and default loopback in gr00t_inference Improvements to the gr00t_inference tool: 1. Input validation for all user-supplied parameters: - data_config and embodiment_tag validated against strict alphanumeric patterns (they are enumerable values from the docstring). - checkpoint_path and trt_engine_path reject shell metacharacters, null bytes, and '..' traversal components. - container_name validated against Docker naming rules. - dtype values checked against explicit allowlists. - Port range validated (1-65535). 2. Default host changed from 0.0.0.0 to 127.0.0.1 (loopback): - Inference services should default to localhost-only binding. - Users can still explicitly pass host='0.0.0.0' when network access is needed. 3. Process verification for stop action: - Added _is_gr00t_process() to verify a PID belongs to a GR00T inference process before sending signals. - Host-system fallback now uses pgrep -f with the inference_service pattern instead of lsof (which matches any process on the port). --- strands_robots/tools/gr00t_inference.py | 129 ++++++++++++++++++++++-- 1 file changed, 119 insertions(+), 10 deletions(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 4921e64..4910b3e 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -7,6 +7,7 @@ """ import os +import re import socket import subprocess import time @@ -14,6 +15,50 @@ from strands import tool +# ───────────────────────────────────────────────────────────────────── +# Input validation helpers +# ───────────────────────────────────────────────────────────────────── + +# Characters that must never appear in values interpolated into commands. +_SHELL_META = re.compile(r"[;&|`$(){}\[\]!<>\\'\"\n\r\x00]") + +# Strict patterns for enumerable parameters. +_DATA_CONFIG_RE = re.compile(r"^[a-z][a-z0-9_]{0,63}$") +_EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$") +_CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$") + + +def _validate_path(value: str, label: str) -> None: + """Reject paths containing shell metacharacters, null bytes, or traversal sequences.""" + if "\x00" in value: + raise ValueError(f"{label} must not contain null bytes") + if ".." in value.split("/"): + raise ValueError(f"{label} must not contain '..' path traversal components") + if _SHELL_META.search(value): + raise ValueError(f"{label} contains disallowed characters: {value!r}") + + +def _validate_data_config(value: str) -> None: + if not _DATA_CONFIG_RE.match(value): + raise ValueError( + f"data_config must be lowercase alphanumeric/underscore (got {value!r}). " + f"See the tool docstring for the full list of accepted configs." + ) + + +def _validate_embodiment_tag(value: str) -> None: + if not _EMBODIMENT_TAG_RE.match(value): + raise ValueError( + f"embodiment_tag must be lowercase alphanumeric/underscore (got {value!r})" + ) + + +def _validate_container_name(value: str) -> None: + if not _CONTAINER_NAME_RE.match(value): + raise ValueError( + f"container_name must match Docker naming rules (got {value!r})" + ) + @tool def gr00t_inference( @@ -24,7 +69,7 @@ def gr00t_inference( data_config: str = "fourier_gr1_arms_only", embodiment_tag: str = "gr1", denoising_steps: int = 4, - host: str = "0.0.0.0", + host: str = "127.0.0.1", container_name: str | None = None, timeout: int = 60, use_tensorrt: bool = False, @@ -112,7 +157,7 @@ def gr00t_inference( data_config: Embodiment data config name (see Data configs above). embodiment_tag: Embodiment tag for the model (e.g., ``gr1``, ``so100``). denoising_steps: Number of denoising steps for action generation (default: 4). - host: Host address to bind the service to (default: ``0.0.0.0``). + host: Host address to bind the service to (default: ``127.0.0.1``). container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). @@ -180,6 +225,30 @@ def gr00t_inference( if api_token is None: api_token = os.environ.get("GROOT_API_TOKEN") + # ── Upfront input validation ────────────────────────────────────── + _validate_data_config(data_config) + _validate_embodiment_tag(embodiment_tag) + if container_name is not None: + _validate_container_name(container_name) + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + _validate_path(trt_engine_path, "trt_engine_path") + + # Validate dtype values (strict allowlist) + _VALID_VIT_DTYPES = {"fp16", "fp8"} + _VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"} + _VALID_DIT_DTYPES = {"fp16", "fp8"} + if vit_dtype not in _VALID_VIT_DTYPES: + return {"status": "error", "message": f"vit_dtype must be one of {_VALID_VIT_DTYPES}"} + if llm_dtype not in _VALID_LLM_DTYPES: + return {"status": "error", "message": f"llm_dtype must be one of {_VALID_LLM_DTYPES}"} + if dit_dtype not in _VALID_DIT_DTYPES: + return {"status": "error", "message": f"dit_dtype must be one of {_VALID_DIT_DTYPES}"} + + # Validate port range + if not (1 <= port <= 65535): + return {"status": "error", "message": "port must be between 1 and 65535"} + if action == "find_containers": return _find_gr00t_containers() elif action == "list": @@ -314,6 +383,27 @@ def _check_service_status(port: int) -> dict[str, Any]: } +def _is_gr00t_process(container_name: str, pid: str) -> bool: + """Verify that a PID inside a container belongs to a GR00T inference process. + + This prevents accidentally killing unrelated processes that happen to + be listening on the same port. + """ + try: + result = subprocess.run( + ["docker", "exec", container_name, "cat", f"/proc/{pid}/cmdline"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + cmdline = result.stdout.replace("\x00", " ") + return "inference_service" in cmdline or "gr00t" in cmdline.lower() + except Exception: + pass + return False + + def _stop_service(port: int) -> dict[str, Any]: """Stop GR00T inference service running on specific port.""" try: @@ -334,13 +424,19 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: - subprocess.run(["docker", "exec", container_name, "kill", "-TERM", pid], check=True) + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid): + subprocess.run( + ["docker", "exec", container_name, "kill", "-TERM", pid], check=True + ) time.sleep(2) result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", "exec", container_name, + "pgrep", "-f", f"inference_service.py.*--port {port}", + ], capture_output=True, text=True, check=False, @@ -349,8 +445,11 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: - subprocess.run(["docker", "exec", container_name, "kill", "-KILL", pid], check=True) + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid): + subprocess.run( + ["docker", "exec", container_name, "kill", "-KILL", pid], check=True + ) return { "status": "success", @@ -362,22 +461,32 @@ def _stop_service(port: int) -> dict[str, Any]: except subprocess.CalledProcessError: continue - # Fallback: try host system - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + # Fallback: try host system — only kill processes that match inference_service + result = subprocess.run( + ["pgrep", "-f", f"inference_service.py.*--port {port}"], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: + pid = pid.strip() if pid: subprocess.run(["kill", "-TERM", pid], check=True) time.sleep(2) - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + result = subprocess.run( + ["pgrep", "-f", f"inference_service.py.*--port {port}"], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: + pid = pid.strip() if pid: subprocess.run(["kill", "-KILL", pid], check=True) From ecc839103320a2adf2bec5d1ba06a5f00a2fb838 Mon Sep 17 00:00:00 2001 From: cagataycali Date: Wed, 1 Apr 2026 19:00:23 -0400 Subject: [PATCH 2/3] style: apply ruff formatting --- strands_robots/tools/gr00t_inference.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 4910b3e..560a79a 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -48,16 +48,12 @@ def _validate_data_config(value: str) -> None: def _validate_embodiment_tag(value: str) -> None: if not _EMBODIMENT_TAG_RE.match(value): - raise ValueError( - f"embodiment_tag must be lowercase alphanumeric/underscore (got {value!r})" - ) + raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {value!r})") def _validate_container_name(value: str) -> None: if not _CONTAINER_NAME_RE.match(value): - raise ValueError( - f"container_name must match Docker naming rules (got {value!r})" - ) + raise ValueError(f"container_name must match Docker naming rules (got {value!r})") @tool @@ -426,16 +422,18 @@ def _stop_service(port: int) -> dict[str, Any]: for pid in pids: pid = pid.strip() if pid and _is_gr00t_process(container_name, pid): - subprocess.run( - ["docker", "exec", container_name, "kill", "-TERM", pid], check=True - ) + subprocess.run(["docker", "exec", container_name, "kill", "-TERM", pid], check=True) time.sleep(2) result = subprocess.run( [ - "docker", "exec", container_name, - "pgrep", "-f", f"inference_service.py.*--port {port}", + "docker", + "exec", + container_name, + "pgrep", + "-f", + f"inference_service.py.*--port {port}", ], capture_output=True, text=True, @@ -447,9 +445,7 @@ def _stop_service(port: int) -> dict[str, Any]: for pid in pids: pid = pid.strip() if pid and _is_gr00t_process(container_name, pid): - subprocess.run( - ["docker", "exec", container_name, "kill", "-KILL", pid], check=True - ) + subprocess.run(["docker", "exec", container_name, "kill", "-KILL", pid], check=True) return { "status": "success", From 4d9634a878b4348f3ef381aca8905a8ef5317c6b Mon Sep 17 00:00:00 2001 From: cagataycali Date: Fri, 3 Apr 2026 20:54:20 +0000 Subject: [PATCH 3/3] refactor: extract validate_inputs() from gr00t_inference tool Encapsulate all input validation (data_config, embodiment_tag, container_name, paths, dtypes, port range) into a single validate_inputs() function. This: 1. Keeps the tool function focused on orchestration 2. Makes validation independently testable 3. Raises ValueError consistently (no mixed return-dict errors) Tests: 15 new tests covering every validation branch. --- strands_robots/tools/gr00t_inference.py | 89 +++++--- .../groot/test_gr00t_inference_validation.py | 194 ++++++++++++++++++ 2 files changed, 251 insertions(+), 32 deletions(-) create mode 100644 tests/groot/test_gr00t_inference_validation.py diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 560a79a..10c23cd 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -27,6 +27,11 @@ _EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$") _CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$") +# Allowlists for TensorRT dtype parameters. +_VALID_VIT_DTYPES = {"fp16", "fp8"} +_VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"} +_VALID_DIT_DTYPES = {"fp16", "fp8"} + def _validate_path(value: str, label: str) -> None: """Reject paths containing shell metacharacters, null bytes, or traversal sequences.""" @@ -38,22 +43,53 @@ def _validate_path(value: str, label: str) -> None: raise ValueError(f"{label} contains disallowed characters: {value!r}") -def _validate_data_config(value: str) -> None: - if not _DATA_CONFIG_RE.match(value): +def validate_inputs( + *, + data_config: str, + embodiment_tag: str, + port: int, + vit_dtype: str, + llm_dtype: str, + dit_dtype: str, + checkpoint_path: str | None = None, + trt_engine_path: str = "gr00t_engine", + container_name: str | None = None, +) -> None: + """Validate all user-supplied parameters in one place. + + Raises ValueError for any invalid input. This centralises validation so + that the main tool function stays focused on orchestration and each + check is independently testable via this single entry-point. + """ + # Enumerable string parameters + if not _DATA_CONFIG_RE.match(data_config): raise ValueError( - f"data_config must be lowercase alphanumeric/underscore (got {value!r}). " + f"data_config must be lowercase alphanumeric/underscore (got {data_config!r}). " f"See the tool docstring for the full list of accepted configs." ) + if not _EMBODIMENT_TAG_RE.match(embodiment_tag): + raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {embodiment_tag!r})") + # Docker container name + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") -def _validate_embodiment_tag(value: str) -> None: - if not _EMBODIMENT_TAG_RE.match(value): - raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {value!r})") + # Filesystem paths — reject shell metacharacters and traversal + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + _validate_path(trt_engine_path, "trt_engine_path") + # TensorRT dtype allowlists + if vit_dtype not in _VALID_VIT_DTYPES: + raise ValueError(f"vit_dtype must be one of {_VALID_VIT_DTYPES}, got {vit_dtype!r}") + if llm_dtype not in _VALID_LLM_DTYPES: + raise ValueError(f"llm_dtype must be one of {_VALID_LLM_DTYPES}, got {llm_dtype!r}") + if dit_dtype not in _VALID_DIT_DTYPES: + raise ValueError(f"dit_dtype must be one of {_VALID_DIT_DTYPES}, got {dit_dtype!r}") -def _validate_container_name(value: str) -> None: - if not _CONTAINER_NAME_RE.match(value): - raise ValueError(f"container_name must match Docker naming rules (got {value!r})") + # Port range + if not (1 <= port <= 65535): + raise ValueError(f"port must be between 1 and 65535, got {port}") @tool @@ -221,29 +257,18 @@ def gr00t_inference( if api_token is None: api_token = os.environ.get("GROOT_API_TOKEN") - # ── Upfront input validation ────────────────────────────────────── - _validate_data_config(data_config) - _validate_embodiment_tag(embodiment_tag) - if container_name is not None: - _validate_container_name(container_name) - if checkpoint_path is not None: - _validate_path(checkpoint_path, "checkpoint_path") - _validate_path(trt_engine_path, "trt_engine_path") - - # Validate dtype values (strict allowlist) - _VALID_VIT_DTYPES = {"fp16", "fp8"} - _VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"} - _VALID_DIT_DTYPES = {"fp16", "fp8"} - if vit_dtype not in _VALID_VIT_DTYPES: - return {"status": "error", "message": f"vit_dtype must be one of {_VALID_VIT_DTYPES}"} - if llm_dtype not in _VALID_LLM_DTYPES: - return {"status": "error", "message": f"llm_dtype must be one of {_VALID_LLM_DTYPES}"} - if dit_dtype not in _VALID_DIT_DTYPES: - return {"status": "error", "message": f"dit_dtype must be one of {_VALID_DIT_DTYPES}"} - - # Validate port range - if not (1 <= port <= 65535): - return {"status": "error", "message": "port must be between 1 and 65535"} + # ── Validate all inputs in one call ─────────────────────────────── + validate_inputs( + data_config=data_config, + embodiment_tag=embodiment_tag, + port=port, + vit_dtype=vit_dtype, + llm_dtype=llm_dtype, + dit_dtype=dit_dtype, + checkpoint_path=checkpoint_path, + trt_engine_path=trt_engine_path, + container_name=container_name, + ) if action == "find_containers": return _find_gr00t_containers() diff --git a/tests/groot/test_gr00t_inference_validation.py b/tests/groot/test_gr00t_inference_validation.py new file mode 100644 index 0000000..f059ce6 --- /dev/null +++ b/tests/groot/test_gr00t_inference_validation.py @@ -0,0 +1,194 @@ +"""Tests for gr00t_inference input validation.""" + +import pytest + +from strands_robots.tools.gr00t_inference import validate_inputs + + +class TestValidateInputs: + """Test the centralised validate_inputs() function.""" + + def test_valid_defaults(self): + """Default parameter values must pass validation.""" + validate_inputs( + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_valid_with_all_optional(self): + """All optional parameters provided with valid values.""" + validate_inputs( + data_config="so100_dualcam", + embodiment_tag="so100", + port=8000, + vit_dtype="fp16", + llm_dtype="fp16", + dit_dtype="fp16", + checkpoint_path="/data/checkpoints/model", + trt_engine_path="my_engine", + container_name="isaac-gr00t-1", + ) + + # ── data_config ────────────────────────────────────────────────── + + def test_invalid_data_config_uppercase(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + data_config="UPPER", + embodiment_tag="gr1", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_invalid_data_config_shell_chars(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + data_config="config;rm -rf /", + embodiment_tag="gr1", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + # ── embodiment_tag ─────────────────────────────────────────────── + + def test_invalid_embodiment_tag(self): + with pytest.raises(ValueError, match="embodiment_tag"): + validate_inputs( + data_config="so100", + embodiment_tag="BAD TAG!", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + # ── port ───────────────────────────────────────────────────────── + + def test_port_zero(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=0, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_port_too_high(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=70000, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + # ── dtype allowlists ───────────────────────────────────────────── + + def test_invalid_vit_dtype(self): + with pytest.raises(ValueError, match="vit_dtype"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="int8", + llm_dtype="nvfp4", + dit_dtype="fp8", + ) + + def test_invalid_llm_dtype(self): + with pytest.raises(ValueError, match="llm_dtype"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="bf16", + dit_dtype="fp8", + ) + + def test_invalid_dit_dtype(self): + with pytest.raises(ValueError, match="dit_dtype"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="nvfp4", + ) + + # ── path validation ────────────────────────────────────────────── + + def test_checkpoint_path_traversal(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/data/../../../etc/passwd", + ) + + def test_checkpoint_path_null_byte(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/data/model\x00.bin", + ) + + def test_trt_engine_path_shell_injection(self): + with pytest.raises(ValueError, match="trt_engine_path"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + trt_engine_path="engine; rm -rf /", + ) + + # ── container_name ─────────────────────────────────────────────── + + def test_invalid_container_name(self): + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + container_name="-invalid", + ) + + def test_container_name_none_is_ok(self): + """container_name=None should pass (auto-detect).""" + validate_inputs( + data_config="so100", + embodiment_tag="so100", + port=5555, + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + container_name=None, + )