From 5232b56501e4ce71367f4fda1ffa623f7d71aad3 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Thu, 5 Mar 2026 11:12:43 -0600 Subject: [PATCH 01/10] first pass at parameterized shell command --- .../templates/shell_command.sh.jinja | 27 ++++ src/groundhog_hpc/templating.py | 96 +++++++++++++ tests/test_templating.py | 135 +++++++++++++++++- 3 files changed, 257 insertions(+), 1 deletion(-) diff --git a/src/groundhog_hpc/templates/shell_command.sh.jinja b/src/groundhog_hpc/templates/shell_command.sh.jinja index 3a81241..e721c51 100644 --- a/src/groundhog_hpc/templates/shell_command.sh.jinja +++ b/src/groundhog_hpc/templates/shell_command.sh.jinja @@ -1,7 +1,12 @@ set -euo pipefail +{% if parameterized %} +TASK_DIR=$(mktemp -d) +trap 'rm -rf "$TASK_DIR"' EXIT +{% else %} # Cleanup temporary files on exit (env is preserved for reuse) trap 'rm -f {{ user_script_name }}.py {{ runner_name }}.py {{ script_name }}.in {{ script_name }}.out' EXIT +{% endif %} if command -v uv &> /dev/null; then UV_BIN=$(command -v uv) @@ -51,6 +56,19 @@ export GROUNDHOG_LOG_LEVEL="${{GROUNDHOG_LOG_LEVEL:-WARNING}}" {% endraw %} {% endif %} +{% if parameterized %} +cat > "$TASK_DIR/user_script.py" << 'USER_SCRIPT_EOF' +{{ user_script_contents | escape_braces }} +USER_SCRIPT_EOF + +cat > "$TASK_DIR/runner.py" << 'RUNNER_EOF' +{{ runner_contents | escape_braces }} +RUNNER_EOF + +cat > "$TASK_DIR/payload.in" << 'PAYLOAD_EOF' +{payload} +PAYLOAD_EOF +{% else %} cat > {{ user_script_name }}.py << 'USER_SCRIPT_EOF' {{ user_script_contents | escape_braces }} USER_SCRIPT_EOF @@ -62,6 +80,7 @@ RUNNER_EOF cat > {{ script_name }}.in << 'PAYLOAD_EOF' {{ payload }} PAYLOAD_EOF +{% endif %} # Check if environment exists; create if not if [ -d "$ENV_DIR" ]; then @@ -115,7 +134,15 @@ META_EOF fi # Run using the cached environment's Python directly (bypasses uv resolution) +{% if parameterized %} +cd "$TASK_DIR" +"$ENV_DIR/bin/python" runner.py + +echo 
"__GROUNDHOG_RESULT__" +cat payload.out +{% else %} "$ENV_DIR/bin/python" {{ runner_name }}.py echo "__GROUNDHOG_RESULT__" cat {{ script_name }}.out +{% endif %} diff --git a/src/groundhog_hpc/templating.py b/src/groundhog_hpc/templating.py index 5d0eb16..b68267d 100644 --- a/src/groundhog_hpc/templating.py +++ b/src/groundhog_hpc/templating.py @@ -176,6 +176,102 @@ def template_shell_command(script_path: str, function_name: str, payload: str) - return shell_command_string +def template_shell_command_parameterized(script_path: str, function_name: str) -> str: + """Generate a parameterized shell command for batch execution. + + Unlike template_shell_command, the payload is NOT baked into the command. + Instead, a {payload} format placeholder is left in the shell command so a + single ShellFunction can be registered once and called with different payloads: + + shell_function(payload=serialized_payload) + + which calls cmd.format(payload=serialized_payload) before execution. + + File isolation is provided by mktemp -d per invocation so concurrent tasks + on the same node don't collide. + + Args: + script_path: Path to the user's Python script + function_name: Name of the function to execute + + Returns: + A shell command string containing a {payload} format placeholder + """ + logger.debug( + f"Templating parameterized shell command for function '{function_name}' in '{script_path}'" + ) + + with open(script_path, "r") as f_in: + user_script = f_in.read() + + metadata = read_pep723(user_script) + pep723_metadata = write_pep723(metadata) if metadata else "" + + if metadata: + env_hash = compute_env_hash(metadata) + else: + logger.warning( + "Script has no PEP 723 metadata. Environment hash based on script content; " + "environment may change unexpectedly between runs." 
+ ) + env_hash = _script_hash_prefix(user_script) + + version_spec = get_groundhog_version_spec() + logger.debug(f"Using groundhog version spec: {version_spec}") + semver_match = re.search(r"==([0-9][^\s]*)", version_spec) + git_hash_match = re.search(r"@([a-f0-9]+)$", version_spec) + if semver_match: + groundhog_version = semver_match.group(1) + elif git_hash_match: + groundhog_version = git_hash_match.group(1) + else: + groundhog_version = _script_hash_prefix(version_spec) + + groundhog_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + templates_dir = Path(__file__).parent / "templates" + jinja_env = Environment(loader=FileSystemLoader(templates_dir)) + jinja_env.filters["escape_braces"] = escape_braces + runner_template = jinja_env.get_template("groundhog_run.py.jinja") + + runner_contents = runner_template.render( + pep723_metadata=pep723_metadata, + script_path="user_script.py", + function_name=function_name, + payload_path="payload.in", + outfile_path="payload.out", + module_name=path_to_module_name(script_path), + ) + + local_log_level = os.getenv("GROUNDHOG_LOG_LEVEL") + if local_log_level: + local_log_level = local_log_level.upper() + logger.debug(f"Propagating log level to remote: {local_log_level}") + + uv_config_toml = _serialize_uv_toml(metadata) + + shell_template = jinja_env.get_template("shell_command.sh.jinja") + shell_command_string = shell_template.render( + parameterized=True, + user_script_contents=user_script, + runner_contents=runner_contents, + version_spec=version_spec, + log_level=local_log_level, + groundhog_timestamp=groundhog_timestamp, + env_hash=env_hash, + groundhog_version=groundhog_version, + requires_python=metadata.requires_python if metadata else "", + dependencies=metadata.dependencies if metadata else [], + uv_config_toml=uv_config_toml, + ) + + logger.debug( + f"Generated parameterized shell command ({len(shell_command_string)} chars)" + ) + + return shell_command_string + + def 
_serialize_uv_toml(metadata: Pep723Metadata | None) -> str: """Serialize [tool.uv] settings to uv.toml format for uv pip install. diff --git a/tests/test_templating.py b/tests/test_templating.py index ac27f7f..e35a639 100644 --- a/tests/test_templating.py +++ b/tests/test_templating.py @@ -2,7 +2,10 @@ import pytest -from groundhog_hpc.templating import template_shell_command +from groundhog_hpc.templating import ( + template_shell_command, + template_shell_command_parameterized, +) class TestTemplateShellCommand: @@ -1003,3 +1006,133 @@ def compute(x): # The runner should use attrgetter for dotted paths assert "attrgetter" in result assert "MyClass.compute" in result + + +MINIMAL_SCRIPT = """\ +# /// script +# requires-python = ">=3.12" +# dependencies = [] +# /// + +import groundhog_hpc as hog + +@hog.function() +def func(): + return 42 +""" + + +class TestTemplateShellCommandParameterized: + """Tests for the parameterized shell command template.""" + + def _write_script(self, tmp_path, content=MINIMAL_SCRIPT): + p = tmp_path / "script.py" + p.write_text(content) + return str(p) + + def test_returns_a_string(self, tmp_path): + script_path = self._write_script(tmp_path) + result = template_shell_command_parameterized(script_path, "func") + assert isinstance(result, str) + assert len(result) > 0 + + def test_contains_payload_placeholder_exactly_once(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + assert cmd.count("{payload}") == 1 + + def test_format_with_payload_kwarg_substitutes_correctly(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + result = cmd.format(payload="__PICKLE__:AAAA==") + assert "__PICKLE__:AAAA==" in result + assert "{payload}" not in result + + def test_format_without_payload_kwarg_raises_key_error(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = 
template_shell_command_parameterized(script_path, "func") + with pytest.raises(KeyError): + cmd.format() + + def test_base64_payload_is_format_safe(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + base64_payload = "__PICKLE__:ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/==" + result = cmd.format(payload=base64_payload) + assert base64_payload in result + + def test_user_code_braces_are_escaped_before_format_call(self, tmp_path): + """Dict literals in user code survive .format(payload=...) without KeyError.""" + script_content = """\ +# /// script +# requires-python = ">=3.12" +# dependencies = [] +# /// + +import groundhog_hpc as hog + +@hog.function() +def func(): + return {"key": "value"} +""" + script_path = self._write_script(tmp_path, script_content) + cmd = template_shell_command_parameterized(script_path, "func") + # Dict braces must be doubled in cmd so .format() doesn't raise KeyError + assert '{{"key": "value"}}' in cmd + # After .format(), doubled braces collapse to single braces (dict literal preserved) + result = cmd.format(payload="test") + assert '{"key": "value"}' in result + + def test_contains_mktemp_for_file_isolation(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + assert "mktemp -d" in cmd + + def test_cleanup_uses_rm_rf_task_dir(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + assert 'rm -rf "$TASK_DIR"' in cmd + # Individual file cleanup should not appear + assert "rm -f " not in cmd + + def test_file_paths_use_fixed_names_inside_task_dir(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + assert "$TASK_DIR/user_script.py" in cmd + assert "$TASK_DIR/runner.py" in cmd + assert "$TASK_DIR/payload.in" in cmd + # 
No random UUID suffixes in paths + import re + + assert not re.search(r"\w+-[0-9a-f]{8}-[0-9a-f]{8}\.py", cmd) + + def test_runner_references_fixed_payload_path(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + assert "open('payload.in'" in cmd + + def test_includes_standard_uv_and_env_reuse_infrastructure(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command_parameterized(script_path, "func") + assert "ENV_HASH=" in cmd + assert "ENV_DIR=" in cmd + assert '"$UV_BIN" venv' in cmd + assert '"$UV_BIN" pip install' in cmd + assert '"$ENV_DIR/bin/python"' in cmd + + def test_different_scripts_produce_different_commands(self, tmp_path): + script1 = tmp_path / "script1.py" + script2 = tmp_path / "script2.py" + script1.write_text(MINIMAL_SCRIPT) + script2.write_text(MINIMAL_SCRIPT.replace("return 42", "return 99")) + cmd1 = template_shell_command_parameterized(str(script1), "func") + cmd2 = template_shell_command_parameterized(str(script2), "func") + assert cmd1 != cmd2 + + def test_non_parameterized_template_is_unchanged(self, tmp_path): + """Existing template_shell_command is unaffected by the template changes.""" + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func", "test_payload") + assert "test_payload" in cmd + assert "mktemp -d" not in cmd + assert "{payload}" not in cmd From 7014c9a068573dbf9ce9b3f2ba87fce5172c7e34 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Thu, 5 Mar 2026 12:24:16 -0600 Subject: [PATCH 02/10] use parameterized shellfunction everywhere --- src/groundhog_hpc/compute.py | 33 ++--- src/groundhog_hpc/function.py | 49 +++++-- tests/conftest.py | 17 ++- tests/test_compute.py | 92 ++++++------ tests/test_function.py | 240 ++++++++++++++++++++++--------- tests/test_mark_import_safe.py | 13 +- tests/test_method.py | 19 +-- 
tests/test_pep723_integration.py | 50 ++++--- 8 files changed, 338 insertions(+), 175 deletions(-) diff --git a/src/groundhog_hpc/compute.py b/src/groundhog_hpc/compute.py index 8e3bb14..0a58415 100644 --- a/src/groundhog_hpc/compute.py +++ b/src/groundhog_hpc/compute.py @@ -1,7 +1,7 @@ """Globus Compute execution interface. -This module provides functions for converting user scripts into Globus Compute -ShellFunctions, registering them, and submitting them for execution on remote +This module provides functions for building Globus Compute ShellFunctions from +pre-rendered shell command strings and submitting them for execution on remote endpoints. """ @@ -13,7 +13,6 @@ from uuid import UUID from groundhog_hpc.future import GroundhogFuture -from groundhog_hpc.templating import template_shell_command logger = logging.getLogger(__name__) @@ -46,43 +45,41 @@ def _get_compute_client() -> Client: return gc.Client() -def script_to_submittable( - script_path: str, - function_name: str, - payload: str, +def build_shell_function( + shell_command: str, + name: str, walltime: int | float | None = None, ) -> ShellFunction: - """Convert a user script and function name into a Globus Compute ShellFunction. + """Create a Globus Compute ShellFunction from a pre-rendered shell command string. 
Args: - script_path: Path to the Python script containing the function - function_name: Name of the function to execute remotely - payload: Serialized arguments string - walltime: Optional maximum execution time in seconds for ShellFunction timeout + shell_command: The shell command string (may contain {payload} placeholder) + name: Function name used as the ShellFunction name (dots replaced with underscores) + walltime: Optional maximum execution time in seconds Returns: A ShellFunction ready to be submitted to a Globus Compute executor """ import globus_compute_sdk as gc - shell_command = template_shell_command(script_path, function_name, payload) - shell_function = gc.ShellFunction( - shell_command, name=function_name.replace(".", "_"), walltime=walltime + return gc.ShellFunction( + shell_command, name=name.replace(".", "_"), walltime=walltime ) - return shell_function def submit_to_executor( endpoint: UUID, user_endpoint_config: dict[str, Any], shell_function: ShellFunction, + payload: str, ) -> GroundhogFuture: """Submit a ShellFunction to a Globus Compute endpoint for execution. 
Args: endpoint: UUID of the Globus Compute endpoint user_endpoint_config: Configuration dict for the endpoint (e.g., worker_init, walltime) - shell_function: The ShellFunction to execute (with payload already templated in) + shell_function: The parameterized ShellFunction to execute + payload: Serialized arguments string, substituted into the {payload} placeholder Returns: A GroundhogFuture that will contain the deserialized result @@ -106,7 +103,7 @@ def submit_to_executor( shell_function, "__name__", getattr(shell_function, "name", "unknown") ) logger.info(f"Submitting function '{func_name}' to endpoint '{endpoint}'") - future = executor.submit(shell_function) + future = executor.submit(shell_function, payload=payload) task_id = getattr(future, "task_id", None) if task_id: logger.info(f"Task submitted with ID: {task_id}") diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 97f641f..1cdda4b 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, TypeVar from uuid import UUID -from groundhog_hpc.compute import script_to_submittable, submit_to_executor +from groundhog_hpc.compute import build_shell_function, submit_to_executor from groundhog_hpc.configuration.resolver import ConfigResolver from groundhog_hpc.console import display_task_status from groundhog_hpc.errors import ( @@ -29,6 +29,7 @@ ) from groundhog_hpc.future import GroundhogFuture from groundhog_hpc.serialization import deserialize_stdout, serialize +from groundhog_hpc.templating import template_shell_command_parameterized from groundhog_hpc.utils import prefix_output logger = logging.getLogger(__name__) @@ -79,12 +80,18 @@ def __init__( # ShellFunction walltime - always None here to prevent conflicts with a # 'walltime' endpoint config, but the attribute exists as an escape - # hatch if users need to set it after the function's been created + # hatch if users need to set it after the function's 
been created. + # NOTE: walltime must be set before the first .submit() or .local() call; + # changing it afterwards has no effect because shell_function is cached. self.walltime: int | float | None = None self._wrapped_function: FunctionType = func self._config_resolver: ConfigResolver | None = None + # Cached parameterized shell command and ShellFunction (built once, reused per instance) + self._shell_command: str | None = None + self._shell_function: ShellFunction | None = None + def __call__(self, *args: Any, **kwargs: Any) -> Any: """Execute the function locally (not remotely). @@ -177,14 +184,12 @@ def submit( f"Serializing {len(args)} args and {len(kwargs)} kwargs for '{self.name}'" ) payload = serialize((args, kwargs), use_proxy=False, proxy_threshold_mb=None) - shell_function = script_to_submittable( - self.script_path, self.name, payload, walltime=self.walltime - ) future: GroundhogFuture = submit_to_executor( UUID(endpoint), user_endpoint_config=config, - shell_function=shell_function, + shell_function=self.shell_function, + payload=payload, ) future.endpoint = endpoint future.user_endpoint_config = config @@ -260,17 +265,15 @@ def local(self, *args: Any, **kwargs: Any) -> Any: logger.debug(f"Executing function '{self.name}' in local subprocess") with prefix_output(prefix="[local]", prefix_color="blue"): - # Create ShellFunction just like we do for remote execution payload = serialize((args, kwargs), proxy_threshold_mb=1.0) - shell_function = script_to_submittable(self.script_path, self.name, payload) with tempfile.TemporaryDirectory() as tmpdir: # set sandbox dir for ShellFunction to use if "GC_TASK_SANDBOX_DIR" not in os.environ: os.environ["GC_TASK_SANDBOX_DIR"] = tmpdir - # just __call__ ShellFunction to execute the command - result = shell_function() + # call ShellFunction with payload as a parameter + result = self.shell_function(payload=payload) assert not isinstance(result, dict) if result.returncode != 0: @@ -305,6 +308,32 @@ def local(self, 
*args: Any, **kwargs: Any) -> Any: print(user_stdout, file=sys.stdout) return deserialized_result + @property + def shell_command(self) -> str: + """Parameterized shell command string with a {payload} placeholder. + + Generated once from the script file and cached. The same command string + is reused for all invocations of this function. + """ + if self._shell_command is None: + self._shell_command = template_shell_command_parameterized( + self.script_path, self.name + ) + return self._shell_command + + @property + def shell_function(self) -> ShellFunction: + """Cached Globus Compute ShellFunction built from the parameterized shell command. + + Created once and reused for all .submit() and .local() calls, so the + same ShellFunction object handles concurrent invocations. + """ + if self._shell_function is None: + self._shell_function = build_shell_function( + self.shell_command, self.name, walltime=self.walltime + ) + return self._shell_function + @property def script_path(self) -> str: """Get the script path for this function. diff --git a/tests/conftest.py b/tests/conftest.py index bfba954..a8924b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,12 @@ import os import sys -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from groundhog_hpc.function import Function + @pytest.fixture(scope="session", autouse=True) def configure_rich_for_ci(): @@ -162,7 +164,7 @@ def mock_submission_stack(): Provides access to all mocks and their return values for assertions. 
Returns dict with: - - script_to_submittable: Mock for script_to_submittable function + - shell_function_prop: PropertyMock for Function.shell_function property - submit_to_executor: Mock for submit_to_executor function - get_endpoint_schema: Mock for get_endpoint_schema function - shell_function: The mock ShellFunction instance @@ -177,9 +179,12 @@ def test_something(mock_submission_stack): mock_shell_func = MagicMock() mock_future = MagicMock() - with patch( - "groundhog_hpc.function.script_to_submittable", return_value=mock_shell_func - ) as mock_script: + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=mock_shell_func, + ) as mock_sf_prop: with patch( "groundhog_hpc.function.submit_to_executor", return_value=mock_future ) as mock_submit: @@ -187,7 +192,7 @@ def test_something(mock_submission_stack): "groundhog_hpc.compute.get_endpoint_schema", return_value={} ) as mock_schema: yield { - "script_to_submittable": mock_script, + "shell_function_prop": mock_sf_prop, "submit_to_executor": mock_submit, "get_endpoint_schema": mock_schema, "shell_function": mock_shell_func, diff --git a/tests/test_compute.py b/tests/test_compute.py index 7c86d3a..8863dfc 100644 --- a/tests/test_compute.py +++ b/tests/test_compute.py @@ -5,56 +5,40 @@ from uuid import UUID from groundhog_hpc.compute import ( - script_to_submittable, + build_shell_function, submit_to_executor, ) -class TestScriptToSubmittable: - """Test the script_to_submittable function.""" +class TestBuildShellFunction: + """Test the build_shell_function helper.""" - def test_creates_shell_function(self, tmp_path): - """Test that script_to_submittable creates a ShellFunction.""" - script_path = tmp_path / "test.py" - script_path.write_text("# test") - payload = "__PICKLE__:test_payload" + def test_creates_shell_function_with_correct_name(self): + """Test that dots in function name are replaced with underscores.""" + with patch("groundhog_hpc.compute.gc.ShellFunction") as 
mock_sf: + build_shell_function("echo test", "my.module.func") + mock_sf.assert_called_once_with( + "echo test", name="my_module_func", walltime=None + ) - with patch("groundhog_hpc.compute.template_shell_command") as mock_template: - mock_template.return_value = "echo test" - with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_shell_func: - _result = script_to_submittable( - str(script_path), "my_function", payload - ) - - # Verify template was called with correct args - mock_template.assert_called_once_with( - str(script_path), "my_function", payload - ) - - # Verify ShellFunction was created with correct args - mock_shell_func.assert_called_once_with( - "echo test", name="my_function", walltime=None - ) - - def test_uses_function_name_as_shell_function_name(self, tmp_path): - """Test that function name is used as the ShellFunction name.""" - script_path = tmp_path / "test.py" - script_path.write_text("# test") - payload = "__PICKLE__:test_payload" + def test_passes_walltime(self): + """Test that walltime is forwarded to ShellFunction.""" + with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_sf: + build_shell_function("echo test", "func", walltime=300) + assert mock_sf.call_args[1]["walltime"] == 300 - with patch("groundhog_hpc.compute.template_shell_command"): - with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_shell_func: - script_to_submittable(str(script_path), "custom_func_name", payload) - - # Verify name was passed - assert mock_shell_func.call_args[1]["name"] == "custom_func_name" + def test_default_walltime_is_none(self): + """Test that walltime defaults to None.""" + with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_sf: + build_shell_function("echo test", "func") + assert mock_sf.call_args[1]["walltime"] is None class TestSubmitToExecutor: """Test the submit_to_executor function.""" def test_creates_executor_and_submits(self, mock_endpoint_uuid, mock_executor): - """Test that Executor is created and submit is 
called.""" + """Test that Executor is created and submit is called with payload.""" mock_shell_func = MagicMock() mock_future = Future() mock_executor.submit.return_value = mock_future @@ -64,7 +48,10 @@ def test_creates_executor_and_submits(self, mock_endpoint_uuid, mock_executor): with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): result = submit_to_executor( - UUID(mock_endpoint_uuid), user_config, mock_shell_func + UUID(mock_endpoint_uuid), + user_config, + mock_shell_func, + payload="test_payload", ) # Verify Executor was created with correct endpoint and config @@ -74,12 +61,30 @@ def test_creates_executor_and_submits(self, mock_endpoint_uuid, mock_executor): UUID(mock_endpoint_uuid), user_endpoint_config=user_config ) - # Verify submit was called with shell function (payload already baked in) - mock_executor.submit.assert_called_once_with(mock_shell_func) + # Verify submit was called with shell function and payload + mock_executor.submit.assert_called_once_with( + mock_shell_func, payload="test_payload" + ) # Result should be a Future (the deserializing one, not the original) assert isinstance(result, Future) + def test_passes_payload_to_executor_submit(self, mock_endpoint_uuid, mock_executor): + """Test that payload is forwarded to executor.submit as keyword argument.""" + mock_shell_func = MagicMock() + mock_future = Future() + mock_executor.submit.return_value = mock_future + + with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): + with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): + submit_to_executor( + UUID(mock_endpoint_uuid), {}, mock_shell_func, payload="abc123" + ) + + mock_executor.submit.assert_called_once_with( + mock_shell_func, payload="abc123" + ) + def test_returns_deserializing_future(self, mock_endpoint_uuid, mock_executor): """Test that a deserializing future is returned, not the 
original.""" mock_shell_func = MagicMock() @@ -89,7 +94,7 @@ def test_returns_deserializing_future(self, mock_endpoint_uuid, mock_executor): with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): result = submit_to_executor( - UUID(mock_endpoint_uuid), {}, mock_shell_func + UUID(mock_endpoint_uuid), {}, mock_shell_func, payload="test" ) # Should return a different future than the one from executor.submit @@ -109,7 +114,10 @@ def test_walltime_in_config_passed_to_executor( with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): submit_to_executor( - UUID(mock_endpoint_uuid), user_config, mock_shell_func + UUID(mock_endpoint_uuid), + user_config, + mock_shell_func, + payload="test", ) # Verify walltime was NOT extracted from config - it should still be present diff --git a/tests/test_function.py b/tests/test_function.py index 76dab4e..f8653d5 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,7 +1,7 @@ """Tests for the Function class.""" import os -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -105,12 +105,11 @@ def test_script_path_raises_when_uninspectable(self): ): _ = func.script_path - def test_submit_creates_shell_function(self, tmp_path, mock_endpoint_uuid): - """Test that submit creates a shell function using script_to_submittable.""" + def test_submit_uses_shell_function_property(self, tmp_path, mock_endpoint_uuid): + """Test that submit uses the cached shell_function property.""" script_path = tmp_path / "test_script.py" - script_content = "# test script content" - script_path.write_text(script_content) + script_path.write_text("# test script content") func = Function(dummy_function, endpoint=mock_endpoint_uuid) func._script_path = str(script_path) @@ -118,26 +117,24 @@ def 
test_submit_creates_shell_function(self, tmp_path, mock_endpoint_uuid): mock_shell_func = MagicMock() mock_future = MagicMock() - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, - ) as mock_script_to_submittable: + ): with patch( "groundhog_hpc.function.submit_to_executor", return_value=mock_future, - ): + ) as mock_submit: with patch( "groundhog_hpc.compute.get_endpoint_schema", return_value={} ): func.submit() - # Verify script_to_submittable was called with correct arguments - mock_script_to_submittable.assert_called_once() - call_args = mock_script_to_submittable.call_args[0] - assert call_args[0] == str(script_path) - assert ( - call_args[1] == "simple_function" - ) # dummy_function is an alias to simple_function + # Verify submit_to_executor was called with the cached shell_function + mock_submit.assert_called_once() + assert mock_submit.call_args[1]["shell_function"] is mock_shell_func class TestSubmitMethod: @@ -188,9 +185,9 @@ def test_submit_serializes_arguments( call_args = mock_serialize.call_args[0][0] assert call_args == ((1, 2), {"kwarg1": "value1"}) - # Verify script_to_submittable received the serialized payload + # Verify submit_to_executor received the serialized payload assert ( - mock_submission_stack["script_to_submittable"].call_args[0][2] + mock_submission_stack["submit_to_executor"].call_args[1]["payload"] == "serialized_payload" ) @@ -265,21 +262,6 @@ def test_callsite_walltime_goes_to_config( config = mock_submit.call_args[1]["user_endpoint_config"] assert config["walltime"] == 120 - def test_function_walltime_sets_shellfunction_walltime( - self, function_with_script, mock_submission_stack - ): - """Test that Function.walltime attribute sets ShellFunction walltime (escape hatch).""" - # Create function and manually set walltime (escape hatch) - func = function_with_script() - func.walltime = 120 - - func.submit() - 
- # Verify script_to_submittable was called with walltime parameter - mock_script_to_submittable = mock_submission_stack["script_to_submittable"] - call_args = mock_script_to_submittable.call_args - assert call_args[1]["walltime"] == 120 - def test_callsite_user_config_overrides_default( self, function_with_script, mock_submission_stack ): @@ -390,8 +372,10 @@ def add(a, b): # Create mock result shell_func, result = mock_local_result(stdout='{"result": 5}') - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, ): with patch( @@ -415,8 +399,10 @@ def test_local_serializes_arguments(self, tmp_path, mock_local_result): with patch( "groundhog_hpc.function.serialize", return_value="serialized" ) as mock_serialize: - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, ): with patch( @@ -456,8 +442,10 @@ def test_local_runs_in_temporary_directory(self, tmp_path): if "GC_TASK_SANDBOX_DIR" in os.environ: del os.environ["GC_TASK_SANDBOX_DIR"] - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_function, ): with patch( @@ -496,41 +484,35 @@ def local_func(): with pytest.raises(ValueError, match="Could not determine script path"): func.local() - def test_local_uses_script_to_submittable(self, tmp_path, mock_local_result): - """Test that local() uses script_to_submittable to create ShellFunction.""" + def test_local_uses_shell_function_property(self, tmp_path, mock_local_result): + """Test that local() uses the cached shell_function property.""" script_path = tmp_path / "test_local.py" script_path.write_text("# test") func = Function(dummy_function) func._script_path = str(script_path) - # Set the import flag to allow .local() call 
- import sys - - test_module = sys.modules.get("tests.test_fixtures") - test_module.__groundhog_imported__ = True - shell_func, result = mock_local_result(stdout="result") - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, - ) as mock_script_to_submittable: + ) as mock_sf_prop: with patch( "groundhog_hpc.function.deserialize_stdout", return_value=(None, "result"), ): func.local() - # Verify script_to_submittable was called with script path, function name, and payload - assert mock_script_to_submittable.call_count == 1 - call_args = mock_script_to_submittable.call_args[0] - assert call_args[0] == str(script_path) - assert call_args[1] == "simple_function" - assert len(call_args) == 3 # script_path, function_name, payload + # Verify shell_function property was accessed (not script_to_submittable) + mock_sf_prop.assert_called() - def test_local_calls_shell_function(self, tmp_path, mock_local_result): - """Test that local() calls the ShellFunction returned by script_to_submittable.""" + def test_local_calls_shell_function_with_payload_kwarg( + self, tmp_path, mock_local_result + ): + """Test that local() calls shell_function(payload=...) 
not shell_function().""" script_path = tmp_path / "test_local.py" script_path.write_text("# test") @@ -539,8 +521,10 @@ def test_local_calls_shell_function(self, tmp_path, mock_local_result): shell_func, result = mock_local_result(stdout="result") - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, ): with patch("groundhog_hpc.function.serialize", return_value="ABC123"): @@ -550,10 +534,9 @@ def test_local_calls_shell_function(self, tmp_path, mock_local_result): ): func.local() - # Verify ShellFunction was called (invoked via __call__) + # Verify ShellFunction was called with payload as keyword argument shell_func.assert_called_once() - # Verify it was called with no arguments (ShellFunction handles its own execution) - assert shell_func.call_args[0] == () + assert shell_func.call_args[1]["payload"] == "ABC123" def test_local_infers_script_path_from_function(self, tmp_path): """Test that local() can infer script path from function's source file.""" @@ -582,8 +565,10 @@ def my_function(): with patch( "groundhog_hpc.function.inspect.getfile", return_value=str(script_path) ): - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_function, ): with patch( @@ -594,6 +579,120 @@ def my_function(): assert result == 42 +class TestShellCommandProperty: + """Test the shell_command lazy-cached property.""" + + def test_calls_template_with_script_path_and_name(self, tmp_path): + """shell_command calls template_shell_command_parameterized with correct args.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + with patch( + "groundhog_hpc.function.template_shell_command_parameterized", + return_value="parameterized_cmd", + ) as mock_template: + result = func.shell_command + + 
mock_template.assert_called_once_with(func._script_path, func.name) + assert result == "parameterized_cmd" + + def test_caches_result_on_second_access(self, tmp_path): + """shell_command returns cached value without re-calling the template.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + with patch( + "groundhog_hpc.function.template_shell_command_parameterized", + return_value="cmd1", + ) as mock_template: + first = func.shell_command + second = func.shell_command + + mock_template.assert_called_once() + assert first == second == "cmd1" + + +class TestShellFunctionProperty: + """Test the shell_function lazy-cached property.""" + + def test_calls_build_shell_function_with_correct_args(self, tmp_path): + """shell_function calls build_shell_function with shell_command, name, walltime.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + func.walltime = 120 + + mock_sf = MagicMock() + + with patch( + "groundhog_hpc.function.template_shell_command_parameterized", + return_value="paramcmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=mock_sf, + ) as mock_build: + result = func.shell_function + + mock_build.assert_called_once_with("paramcmd", func.name, walltime=120) + assert result is mock_sf + + def test_caches_result_on_second_access(self, tmp_path): + """shell_function returns cached value without re-calling build_shell_function.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + mock_sf = MagicMock() + + with patch( + "groundhog_hpc.function.template_shell_command_parameterized", + return_value="cmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=mock_sf, + ) as mock_build: + first = func.shell_function + second = func.shell_function + + mock_build.assert_called_once() + assert first is second is mock_sf + + def test_default_walltime_is_none(self, tmp_path): + """shell_function 
passes walltime=None when not set.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + with patch( + "groundhog_hpc.function.template_shell_command_parameterized", + return_value="cmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=MagicMock(), + ) as mock_build: + func.shell_function + + assert mock_build.call_args[1]["walltime"] is None + + def test_walltime_flows_into_shell_function(self, tmp_path): + """walltime set before first access is used by build_shell_function.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + func.walltime = 300 + + with patch( + "groundhog_hpc.function.template_shell_command_parameterized", + return_value="cmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=MagicMock(), + ) as mock_build: + func.shell_function + + assert mock_build.call_args[1]["walltime"] == 300 + + class TestLocalAlwaysUsesSubprocess: """Test that .local() always uses subprocess (no direct call fallback).""" @@ -626,11 +725,13 @@ def test_func(x): shell_func, result = mock_local_result(stdout="84") - # Mock script_to_submittable to verify subprocess is used - with patch( - "groundhog_hpc.function.script_to_submittable", + # Patch shell_function property to verify subprocess is used + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, - ) as mock_script_to_submittable: + ): with patch( "groundhog_hpc.function.deserialize_stdout", return_value=(None, 84) ): @@ -638,5 +739,4 @@ def test_func(x): # Should always use subprocess (ShellFunction) assert result_value == 84 - mock_script_to_submittable.assert_called_once() shell_func.assert_called_once() diff --git a/tests/test_mark_import_safe.py b/tests/test_mark_import_safe.py index 48b94a6..1259c22 100644 --- a/tests/test_mark_import_safe.py +++ b/tests/test_mark_import_safe.py @@ -2,7 +2,7 @@ import sys import types -from 
unittest.mock import Mock, patch +from unittest.mock import Mock, PropertyMock, patch import pytest @@ -174,7 +174,7 @@ def my_func(): # Verify flag is set assert module.__groundhog_imported__ is True - # Mock script_to_submittable to avoid actual subprocess execution + # Mock shell_function property to avoid actual subprocess execution mock_shell_func = Mock() mock_result = Mock() mock_result.returncode = 0 @@ -182,8 +182,13 @@ def my_func(): mock_result.stderr = "" mock_shell_func.return_value = mock_result - with patch( - "groundhog_hpc.function.script_to_submittable", return_value=mock_shell_func + from groundhog_hpc.function import Function + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=mock_shell_func, ): # Now .local() should work (won't raise ModuleImportError) result = module.my_func.local() diff --git a/tests/test_method.py b/tests/test_method.py index 28b321b..ac9c33d 100644 --- a/tests/test_method.py +++ b/tests/test_method.py @@ -1,6 +1,6 @@ """Tests for the Method class.""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from groundhog_hpc.function import Function, Method @@ -93,19 +93,22 @@ def compute(x): mock_shell_func = MagicMock() mock_future = MagicMock() - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, - ) as mock_script_to_submittable: + ): with patch( "groundhog_hpc.function.submit_to_executor", return_value=mock_future, - ): + ) as mock_submit: with patch( "groundhog_hpc.compute.get_endpoint_schema", return_value={} ): method.submit(5) - # Verify qualname was passed correctly - call_args = mock_script_to_submittable.call_args[0] - assert call_args[1] == "MyClass.compute" + # Verify the function name (qualname) is used — visible in the shell_function property name + assert method.name == "MyClass.compute" + # 
Verify submit was called (method uses the same submit path as Function) + mock_submit.assert_called_once() diff --git a/tests/test_pep723_integration.py b/tests/test_pep723_integration.py index 05dd381..3807ecf 100644 --- a/tests/test_pep723_integration.py +++ b/tests/test_pep723_integration.py @@ -7,7 +7,7 @@ import os import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from uuid import UUID import pytest @@ -69,8 +69,10 @@ def test_func(): "qos": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -152,8 +154,10 @@ def test_func(): "partition": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -232,8 +236,10 @@ def test_func(): "cores": {"type": "integer"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -304,8 +310,10 @@ def test_func(): "qos": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -372,8 +380,10 @@ def test_func(): # Mock schema that includes worker_init mock_schema = {"properties": {"worker_init": {"type": "string"}}} - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -455,8 +465,10 @@ def test_func(): } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with 
patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -482,8 +494,10 @@ def test_func(): # Reset mock mock_submit.reset_mock() - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -561,8 +575,10 @@ def test_func(): "qos": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( From 5d536e8517622eb8f35dcf94c9e5eed591d68959 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:42:21 -0600 Subject: [PATCH 03/10] clean up unused template logic --- src/groundhog_hpc/function.py | 6 +- .../templates/shell_command.sh.jinja | 26 ---- src/groundhog_hpc/templating.py | 135 +----------------- tests/test_function.py | 14 +- tests/test_templating.py | 120 +++++++--------- 5 files changed, 66 insertions(+), 235 deletions(-) diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 1cdda4b..84e3330 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -29,7 +29,7 @@ ) from groundhog_hpc.future import GroundhogFuture from groundhog_hpc.serialization import deserialize_stdout, serialize -from groundhog_hpc.templating import template_shell_command_parameterized +from groundhog_hpc.templating import template_shell_command from groundhog_hpc.utils import prefix_output logger = logging.getLogger(__name__) @@ -316,9 +316,7 @@ def shell_command(self) -> str: is reused for all invocations of this function. 
""" if self._shell_command is None: - self._shell_command = template_shell_command_parameterized( - self.script_path, self.name - ) + self._shell_command = template_shell_command(self.script_path, self.name) return self._shell_command @property diff --git a/src/groundhog_hpc/templates/shell_command.sh.jinja b/src/groundhog_hpc/templates/shell_command.sh.jinja index e721c51..c3e08e8 100644 --- a/src/groundhog_hpc/templates/shell_command.sh.jinja +++ b/src/groundhog_hpc/templates/shell_command.sh.jinja @@ -1,12 +1,7 @@ set -euo pipefail -{% if parameterized %} TASK_DIR=$(mktemp -d) trap 'rm -rf "$TASK_DIR"' EXIT -{% else %} -# Cleanup temporary files on exit (env is preserved for reuse) -trap 'rm -f {{ user_script_name }}.py {{ runner_name }}.py {{ script_name }}.in {{ script_name }}.out' EXIT -{% endif %} if command -v uv &> /dev/null; then UV_BIN=$(command -v uv) @@ -56,7 +51,6 @@ export GROUNDHOG_LOG_LEVEL="${{GROUNDHOG_LOG_LEVEL:-WARNING}}" {% endraw %} {% endif %} -{% if parameterized %} cat > "$TASK_DIR/user_script.py" << 'USER_SCRIPT_EOF' {{ user_script_contents | escape_braces }} USER_SCRIPT_EOF @@ -68,19 +62,6 @@ RUNNER_EOF cat > "$TASK_DIR/payload.in" << 'PAYLOAD_EOF' {payload} PAYLOAD_EOF -{% else %} -cat > {{ user_script_name }}.py << 'USER_SCRIPT_EOF' -{{ user_script_contents | escape_braces }} -USER_SCRIPT_EOF - -cat > {{ runner_name }}.py << 'RUNNER_EOF' -{{ runner_contents | escape_braces }} -RUNNER_EOF - -cat > {{ script_name }}.in << 'PAYLOAD_EOF' -{{ payload }} -PAYLOAD_EOF -{% endif %} # Check if environment exists; create if not if [ -d "$ENV_DIR" ]; then @@ -134,15 +115,8 @@ META_EOF fi # Run using the cached environment's Python directly (bypasses uv resolution) -{% if parameterized %} cd "$TASK_DIR" "$ENV_DIR/bin/python" runner.py echo "__GROUNDHOG_RESULT__" cat payload.out -{% else %} -"$ENV_DIR/bin/python" {{ runner_name }}.py - -echo "__GROUNDHOG_RESULT__" -cat {{ script_name }}.out -{% endif %} diff --git a/src/groundhog_hpc/templating.py 
b/src/groundhog_hpc/templating.py index b68267d..020483c 100644 --- a/src/groundhog_hpc/templating.py +++ b/src/groundhog_hpc/templating.py @@ -11,7 +11,6 @@ import logging import os import re -import uuid from datetime import datetime, timezone from hashlib import sha1 from pathlib import Path @@ -63,125 +62,12 @@ def compute_env_hash(metadata: Pep723Metadata) -> str: return sha1(canonical.encode("utf-8")).hexdigest()[:8] -def template_shell_command(script_path: str, function_name: str, payload: str) -> str: - """Generate a shell command to execute a user function on a remote endpoint. +def template_shell_command(script_path: str, function_name: str) -> str: + """Generate a parameterized shell command for remote execution. - The generated shell command: - - Creates a runner script that imports the user script as a module - - Writes the user script to a file (unmodified) - - Sets up input/output files for serialized data - - Executes the runner with uv for dependency management - - Args: - script_path: Path to the user's Python script - function_name: Name of the function to execute - payload: Serialized arguments string - - Returns: - A fully-formed shell command string ready to be executed via Globus - Compute or local subprocess - """ - logger.debug( - f"Templating shell command for function '{function_name}' in script '{script_path}'" - ) - - with open(script_path, "r") as f_in: - user_script = f_in.read() - - # Extract PEP 723 metadata for the runner - metadata = read_pep723(user_script) - pep723_metadata = write_pep723(metadata) if metadata else "" - - if metadata: - env_hash = compute_env_hash(metadata) - else: - logger.warning( - "Script has no PEP 723 metadata. Environment hash based on script content; " - "environment may change unexpectedly between runs." 
- ) - env_hash = _script_hash_prefix(user_script) - - script_hash = _script_hash_prefix(user_script) - script_basename = _extract_script_basename(script_path) - random_suffix = uuid.uuid4().hex[:8] - script_name = f"{script_basename}-{script_hash}-{random_suffix}" - - # Generate names for the user script and runner - user_script_name = script_name - runner_name = f"{script_name}_runner" - user_script_path_remote = f"{user_script_name}.py" - payload_path = f"{script_name}.in" - outfile_path = f"{script_name}.out" - - version_spec = get_groundhog_version_spec() - logger.debug(f"Using groundhog version spec: {version_spec}") - semver_match = re.search(r"==([0-9][^\s]*)", version_spec) - git_hash_match = re.search(r"@([a-f0-9]+)$", version_spec) - if semver_match: - groundhog_version = semver_match.group(1) - elif git_hash_match: - groundhog_version = git_hash_match.group(1) - else: - groundhog_version = _script_hash_prefix(version_spec) - - # Generate timestamp for groundhog-hpc exclude-newer override - # This allows groundhog to bypass user's exclude-newer restrictions - groundhog_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - - # Load runner template - templates_dir = Path(__file__).parent / "templates" - jinja_env = Environment(loader=FileSystemLoader(templates_dir)) - jinja_env.filters["escape_braces"] = escape_braces - runner_template = jinja_env.get_template("groundhog_run.py.jinja") - - # Render runner script - runner_contents = runner_template.render( - pep723_metadata=pep723_metadata, - script_path=user_script_path_remote, - function_name=function_name, - payload_path=payload_path, - outfile_path=outfile_path, - module_name=path_to_module_name(script_path), - ) - - # Read local log level (None if not set) - local_log_level = os.getenv("GROUNDHOG_LOG_LEVEL") - if local_log_level: - local_log_level = local_log_level.upper() - logger.debug(f"Propagating log level to remote: {local_log_level}") - - uv_config_toml = 
_serialize_uv_toml(metadata) - - # Render shell command - shell_template = jinja_env.get_template("shell_command.sh.jinja") - shell_command_string = shell_template.render( - user_script_name=user_script_name, - user_script_contents=user_script, - runner_name=runner_name, - runner_contents=runner_contents, - script_name=script_name, - version_spec=version_spec, - payload=payload, - log_level=local_log_level, - groundhog_timestamp=groundhog_timestamp, - env_hash=env_hash, - groundhog_version=groundhog_version, - requires_python=metadata.requires_python if metadata else "", - dependencies=metadata.dependencies if metadata else [], - uv_config_toml=uv_config_toml, - ) - - logger.debug(f"Generated shell command ({len(shell_command_string)} chars)") - - return shell_command_string - - -def template_shell_command_parameterized(script_path: str, function_name: str) -> str: - """Generate a parameterized shell command for batch execution. - - Unlike template_shell_command, the payload is NOT baked into the command. - Instead, a {payload} format placeholder is left in the shell command so a - single ShellFunction can be registered once and called with different payloads: + The payload is NOT baked into the command. 
Instead, a {payload} format + placeholder is left so a single ShellFunction can be reused for all + invocations of the same function: shell_function(payload=serialized_payload) @@ -198,7 +84,7 @@ def template_shell_command_parameterized(script_path: str, function_name: str) - A shell command string containing a {payload} format placeholder """ logger.debug( - f"Templating parameterized shell command for function '{function_name}' in '{script_path}'" + f"Templating shell command for function '{function_name}' in '{script_path}'" ) with open(script_path, "r") as f_in: @@ -252,7 +138,6 @@ def template_shell_command_parameterized(script_path: str, function_name: str) - shell_template = jinja_env.get_template("shell_command.sh.jinja") shell_command_string = shell_template.render( - parameterized=True, user_script_contents=user_script, runner_contents=runner_contents, version_spec=version_spec, @@ -265,9 +150,7 @@ def template_shell_command_parameterized(script_path: str, function_name: str) - uv_config_toml=uv_config_toml, ) - logger.debug( - f"Generated parameterized shell command ({len(shell_command_string)} chars)" - ) + logger.debug(f"Generated shell command ({len(shell_command_string)} chars)") return shell_command_string @@ -290,7 +173,3 @@ def _serialize_uv_toml(metadata: Pep723Metadata | None) -> str: def _script_hash_prefix(contents: str, length: int = 8) -> str: return str(sha1(bytes(contents, "utf-8")).hexdigest()[:length]) - - -def _extract_script_basename(script_path: str) -> str: - return Path(script_path).stem diff --git a/tests/test_function.py b/tests/test_function.py index f8653d5..8b9d036 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -583,12 +583,12 @@ class TestShellCommandProperty: """Test the shell_command lazy-cached property.""" def test_calls_template_with_script_path_and_name(self, tmp_path): - """shell_command calls template_shell_command_parameterized with correct args.""" + """shell_command calls template_shell_command 
with correct args.""" func = Function(dummy_function) func._script_path = str(tmp_path / "fake.py") with patch( - "groundhog_hpc.function.template_shell_command_parameterized", + "groundhog_hpc.function.template_shell_command", return_value="parameterized_cmd", ) as mock_template: result = func.shell_command @@ -602,7 +602,7 @@ def test_caches_result_on_second_access(self, tmp_path): func._script_path = str(tmp_path / "fake.py") with patch( - "groundhog_hpc.function.template_shell_command_parameterized", + "groundhog_hpc.function.template_shell_command", return_value="cmd1", ) as mock_template: first = func.shell_command @@ -624,7 +624,7 @@ def test_calls_build_shell_function_with_correct_args(self, tmp_path): mock_sf = MagicMock() with patch( - "groundhog_hpc.function.template_shell_command_parameterized", + "groundhog_hpc.function.template_shell_command", return_value="paramcmd", ): with patch( @@ -644,7 +644,7 @@ def test_caches_result_on_second_access(self, tmp_path): mock_sf = MagicMock() with patch( - "groundhog_hpc.function.template_shell_command_parameterized", + "groundhog_hpc.function.template_shell_command", return_value="cmd", ): with patch( @@ -663,7 +663,7 @@ def test_default_walltime_is_none(self, tmp_path): func._script_path = str(tmp_path / "fake.py") with patch( - "groundhog_hpc.function.template_shell_command_parameterized", + "groundhog_hpc.function.template_shell_command", return_value="cmd", ): with patch( @@ -681,7 +681,7 @@ def test_walltime_flows_into_shell_function(self, tmp_path): func.walltime = 300 with patch( - "groundhog_hpc.function.template_shell_command_parameterized", + "groundhog_hpc.function.template_shell_command", return_value="cmd", ): with patch( diff --git a/tests/test_templating.py b/tests/test_templating.py index e35a639..aeccb66 100644 --- a/tests/test_templating.py +++ b/tests/test_templating.py @@ -2,10 +2,7 @@ import pytest -from groundhog_hpc.templating import ( - template_shell_command, - 
template_shell_command_parameterized, -) +from groundhog_hpc.templating import template_shell_command class TestTemplateShellCommand: @@ -31,7 +28,7 @@ def foo(): script_path.write_text(script_content) # Should not raise any errors - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") assert isinstance(shell_command, str) # User script should be included as-is (with __main__ block) assert 'if __name__ == "__main__":' in shell_command @@ -52,13 +49,14 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") - # Should create both user script and runner - assert "_runner.py" in shell_command + # Should create runner in TASK_DIR + assert "$TASK_DIR/runner.py" in shell_command # Runner should import the user script assert ( - 'module = import_user_script("test_script", "test_script-' in shell_command + 'module = import_user_script("test_script", "user_script.py")' + in shell_command ) # Runner should invoke the target function using attrgetter assert 'func = attrgetter("foo")(module)' in shell_command @@ -79,7 +77,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Runner should contain the metadata assert 'requires-python = ">=3.12"' in shell_command @@ -108,7 +106,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Runner should contain the [tool.uv] section assert "[tool.uv]" in shell_command @@ -133,7 +131,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = 
template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Should NOT contain --managed-python (it's now in [tool.uv]) assert "--managed-python" not in shell_command @@ -157,7 +155,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Check that it's a non-empty string assert isinstance(shell_command, str) @@ -178,9 +176,7 @@ def test_func(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "test_func", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "test_func") # Should include the basename assert "my_script" in shell_command @@ -200,9 +196,7 @@ def my_function(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "my_function", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "my_function") assert "my_function" in shell_command @@ -221,11 +215,10 @@ def func(): """ script_path.write_text(script_content) - test_payload = "MY_TEST_PAYLOAD_12345" - shell_command = template_shell_command(str(script_path), "func", test_payload) + shell_command = template_shell_command(str(script_path), "func") - # Payload should be rendered directly in the command (via Jinja2) - assert test_payload in shell_command + # Command should contain the {payload} placeholder (filled in at call time) + assert "{payload}" in shell_command def test_includes_uv_commands(self, tmp_path): """Test that the shell command uses uv for env creation.""" @@ -242,7 +235,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "test_payload") + shell_command = template_shell_command(str(script_path), "func") # Check for uv installation assert 
"uv.find_uv_bin()" in shell_command @@ -265,9 +258,7 @@ def dict_func(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "dict_func", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "dict_func") # Curly braces in user code should be doubled (escaped via Jinja2 filter) # This is needed because Globus Compute's ShellFunction calls .format() @@ -296,20 +287,18 @@ def use_torch(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "use_torch", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "use_torch") # Simulate what Globus Compute's ShellFunction does: - # It calls .format() on the command (without any kwargs) + # It calls .format(payload=...) on the command try: # This should not raise KeyError if curly braces are properly escaped - formatted = shell_command.format() + formatted = shell_command.format(payload="test_payload") # After .format(), the doubled braces should become single braces assert '{"torch"' in formatted except KeyError as e: pytest.fail( - f"shell_command.format() raised KeyError: {e}. " + f"shell_command.format(payload=...) raised KeyError: {e}. " "This means curly braces in user code are not properly escaped!" 
) @@ -341,8 +330,8 @@ def func2(): script1_path.write_text(script1_content) script2_path.write_text(script2_content) - command1 = template_shell_command(str(script1_path), "func1", "test_payload") - command2 = template_shell_command(str(script2_path), "func2", "test_payload") + command1 = template_shell_command(str(script1_path), "func1") + command2 = template_shell_command(str(script2_path), "func2") # Extract the script names (format: basename-hash) # They should have different hashes since content differs @@ -370,7 +359,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "test_payload") + shell_command = template_shell_command(str(script_path), "func") # Should include the package-specific exclude-newer override assert "--exclude-newer-package groundhog-hpc=" in shell_command @@ -562,7 +551,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "ENV_HASH=" in shell_command @@ -582,7 +571,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "groundhog-envs" in shell_command assert "ENV_DIR=" in shell_command @@ -603,7 +592,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert 'if [ -d "$ENV_DIR" ]' in shell_command assert '"$UV_BIN" venv' in shell_command @@ -625,7 +614,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert '"$ENV_DIR/bin/python"' in shell_command assert 
'"$UV_BIN" run' not in shell_command @@ -646,7 +635,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "groundhog-meta.json" in shell_command assert '"requires_python":' in shell_command @@ -668,7 +657,7 @@ def func(): script_path.write_text(script_content) with caplog.at_level(logging.WARNING): - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "ENV_HASH=" in shell_command assert any( @@ -851,7 +840,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert '"$ENV_DIR/uv.toml"' in shell_command assert 'exclude-newer = "2025-01-01T00:00:00Z"' in shell_command @@ -875,7 +864,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert '--config-file "$ENV_DIR/uv.toml"' in shell_command @@ -897,7 +886,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") # --exclude-newer as a standalone CLI flag should be gone import re @@ -926,7 +915,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") # uv venv line should carry --config-file venv_line = next( @@ -954,7 +943,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") toml_write_pos = shell_command.find("UV_CONFIG_EOF") venv_pos = shell_command.find('"$UV_BIN" 
venv') @@ -974,7 +963,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "UV_CONFIG_EOF" not in shell_command assert "--config-file" not in shell_command @@ -1000,7 +989,6 @@ def compute(x): result = template_shell_command( str(script_path), "MyClass.compute", # Dotted qualname - "[[1], {}]", ) # The runner should use attrgetter for dotted paths @@ -1032,31 +1020,31 @@ def _write_script(self, tmp_path, content=MINIMAL_SCRIPT): def test_returns_a_string(self, tmp_path): script_path = self._write_script(tmp_path) - result = template_shell_command_parameterized(script_path, "func") + result = template_shell_command(script_path, "func") assert isinstance(result, str) assert len(result) > 0 def test_contains_payload_placeholder_exactly_once(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") assert cmd.count("{payload}") == 1 def test_format_with_payload_kwarg_substitutes_correctly(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") result = cmd.format(payload="__PICKLE__:AAAA==") assert "__PICKLE__:AAAA==" in result assert "{payload}" not in result def test_format_without_payload_kwarg_raises_key_error(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") with pytest.raises(KeyError): cmd.format() def test_base64_payload_is_format_safe(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") base64_payload = 
"__PICKLE__:ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/==" result = cmd.format(payload=base64_payload) assert base64_payload in result @@ -1076,7 +1064,7 @@ def func(): return {"key": "value"} """ script_path = self._write_script(tmp_path, script_content) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") # Dict braces must be doubled in cmd so .format() doesn't raise KeyError assert '{{"key": "value"}}' in cmd # After .format(), doubled braces collapse to single braces (dict literal preserved) @@ -1085,19 +1073,19 @@ def func(): def test_contains_mktemp_for_file_isolation(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") assert "mktemp -d" in cmd def test_cleanup_uses_rm_rf_task_dir(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") assert 'rm -rf "$TASK_DIR"' in cmd # Individual file cleanup should not appear assert "rm -f " not in cmd def test_file_paths_use_fixed_names_inside_task_dir(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") assert "$TASK_DIR/user_script.py" in cmd assert "$TASK_DIR/runner.py" in cmd assert "$TASK_DIR/payload.in" in cmd @@ -1108,12 +1096,12 @@ def test_file_paths_use_fixed_names_inside_task_dir(self, tmp_path): def test_runner_references_fixed_payload_path(self, tmp_path): script_path = self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") assert "open('payload.in'" in cmd def test_includes_standard_uv_and_env_reuse_infrastructure(self, tmp_path): script_path = 
self._write_script(tmp_path) - cmd = template_shell_command_parameterized(script_path, "func") + cmd = template_shell_command(script_path, "func") assert "ENV_HASH=" in cmd assert "ENV_DIR=" in cmd assert '"$UV_BIN" venv' in cmd @@ -1125,14 +1113,6 @@ def test_different_scripts_produce_different_commands(self, tmp_path): script2 = tmp_path / "script2.py" script1.write_text(MINIMAL_SCRIPT) script2.write_text(MINIMAL_SCRIPT.replace("return 42", "return 99")) - cmd1 = template_shell_command_parameterized(str(script1), "func") - cmd2 = template_shell_command_parameterized(str(script2), "func") + cmd1 = template_shell_command(str(script1), "func") + cmd2 = template_shell_command(str(script2), "func") assert cmd1 != cmd2 - - def test_non_parameterized_template_is_unchanged(self, tmp_path): - """Existing template_shell_command is unaffected by the template changes.""" - script_path = self._write_script(tmp_path) - cmd = template_shell_command(script_path, "func", "test_payload") - assert "test_payload" in cmd - assert "mktemp -d" not in cmd - assert "{payload}" not in cmd From e72b239cdc04b92e996937f90959a580509944c8 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:30:07 -0600 Subject: [PATCH 04/10] first pass at submit_batch --- src/groundhog_hpc/compute.py | 100 ++++++++++++ src/groundhog_hpc/function.py | 38 ++++- src/groundhog_hpc/future.py | 2 +- tests/conftest.py | 14 +- tests/test_compute.py | 196 +++++++++++++++++++++++ tests/test_function.py | 274 ++++++++++++++++++++------------- tests/test_future.py | 30 ++-- tests/test_mark_import_safe.py | 23 +-- 8 files changed, 536 insertions(+), 141 deletions(-) diff --git a/src/groundhog_hpc/compute.py b/src/groundhog_hpc/compute.py index 0a58415..e5ca8da 100644 --- a/src/groundhog_hpc/compute.py +++ b/src/groundhog_hpc/compute.py @@ -7,7 +7,9 @@ import logging import os +import threading import warnings +from concurrent.futures import Future as 
ConcurrentFuture from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeVar from uuid import UUID @@ -111,6 +113,104 @@ def submit_to_executor( return deserializing_future +def submit_batch( + endpoint: UUID, + user_endpoint_config: dict[str, Any], + shell_function: ShellFunction, + payloads: list[str], +) -> list[GroundhogFuture]: + """Submit a parameterized ShellFunction as a batch of tasks to a Globus Compute endpoint. + + Registers the ShellFunction once, then submits all payloads as a single batch + request, avoiding per-task API calls that can hit rate limits. + + Args: + endpoint: UUID of the Globus Compute endpoint + user_endpoint_config: Configuration dict for the endpoint + shell_function: The parameterized ShellFunction (registered once for all tasks) + payloads: List of serialized argument strings, one per task + + Returns: + A list of GroundhogFutures in the same order as payloads + """ + client = _get_compute_client() + + config = user_endpoint_config.copy() + if schema := get_endpoint_schema(endpoint): + expected_keys = set(schema.get("properties", {}).keys()) + unexpected_keys = set(config.keys()) - expected_keys + if unexpected_keys: + logger.debug( + f"Filtering unexpected config keys for endpoint {endpoint}: {unexpected_keys}" + ) + config = {k: v for k, v in config.items() if k not in unexpected_keys} + + func_name = getattr(shell_function, "__name__", "unknown") + function_id = client.register_function(shell_function) + logger.info( + f"Registered '{func_name}' for batch submission, function_id={function_id}" + ) + + batch = client.create_batch(user_endpoint_config=config) + for payload in payloads: + batch.add(function_id, kwargs={"payload": payload}) + + response = client.batch_run(endpoint, batch) + task_ids: list[str] = response["tasks"][function_id] + logger.info(f"Batch submitted: {len(task_ids)} tasks to endpoint '{endpoint}'") + + task_id_to_future: dict[str, ConcurrentFuture] = { + tid: ConcurrentFuture() for tid in 
task_ids + } + + thread = threading.Thread( + target=_poll_batch_results, + args=(dict(task_id_to_future), client), + daemon=True, + ) + thread.start() + + futures = [] + for task_id in task_ids: + gf = GroundhogFuture(task_id_to_future[task_id]) + gf._task_id = task_id + futures.append(gf) + + return futures + + +def _poll_batch_results( + task_id_to_future: dict[str, ConcurrentFuture], + client: Client, + poll_interval: float = 1.0, +) -> None: + """Background thread: poll Globus Compute until all batch tasks are resolved.""" + import time + + pending = dict(task_id_to_future) + + while pending: + results = client.get_batch_result(list(pending.keys())) + for task_id, status in results.items(): + if status.get("pending", True): + continue + fut = pending.pop(task_id) + try: + if "result" in status: + fut.set_result(status["result"]) + else: + try: + status["exception"].reraise() + except Exception as e: + fut.set_exception(e) + except Exception as e: + if not fut.done(): + fut.set_exception(e) + + if pending: + time.sleep(poll_interval) + + def get_task_status(task_id: str | UUID | None) -> dict[str, Any]: """Get the full task status response from Globus Compute. diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 84e3330..5c56442 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -12,6 +12,7 @@ import inspect import logging import os +import subprocess import sys import tempfile from pathlib import Path @@ -44,6 +45,35 @@ ShellResult = TypeVar("ShellResult") +def _run_shell_locally(cmd_template: str, payload: str, tmpdir: str) -> Any: + """Execute a parameterized shell command locally. + + Injects GC_TASK_SANDBOX_DIR into the subprocess environment without + mutating os.environ, making concurrent calls thread-safe. 
+ """ + import globus_compute_sdk as gc + + env = {**os.environ, "GC_TASK_SANDBOX_DIR": tmpdir} + cmd = cmd_template.format(payload=payload) + proc = subprocess.run( + cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + text=True, + env=env, + ) + return gc.ShellResult( + cmd=cmd, + stdout=proc.stdout, + stderr=proc.stderr, + returncode=proc.returncode, + exception_name="subprocess.CalledProcessError" + if proc.returncode != 0 + else None, + ) + + class Function: """Wrapper that enables a Python function to be executed remotely on Globus Compute. @@ -268,13 +298,7 @@ def local(self, *args: Any, **kwargs: Any) -> Any: payload = serialize((args, kwargs), proxy_threshold_mb=1.0) with tempfile.TemporaryDirectory() as tmpdir: - # set sandbox dir for ShellFunction to use - if "GC_TASK_SANDBOX_DIR" not in os.environ: - os.environ["GC_TASK_SANDBOX_DIR"] = tmpdir - - # call ShellFunction with payload as a parameter - result = self.shell_function(payload=payload) - assert not isinstance(result, dict) + result = _run_shell_locally(self.shell_function.cmd, payload, tmpdir) if result.returncode != 0: logger.error( diff --git a/src/groundhog_hpc/future.py b/src/groundhog_hpc/future.py index d5597fc..2f09750 100644 --- a/src/groundhog_hpc/future.py +++ b/src/groundhog_hpc/future.py @@ -98,7 +98,7 @@ def task_id(self) -> str | None: Returns the task ID from the underlying Globus Compute future, which may not be populated immediately. """ - return self._original_future.task_id # type: ignore[attr-defined] + return self._task_id or getattr(self._original_future, "task_id", None) @property def endpoint(self) -> str | None: diff --git a/tests/conftest.py b/tests/conftest.py index a8924b4..f95d35a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -307,17 +307,18 @@ def test_something(mock_executor): @pytest.fixture def mock_local_result(): - """Create a mock result for local subprocess execution. + """Create mock objects for local subprocess execution tests. 
Returns a factory function that creates: - - A mock ShellFunction that returns the result - - The mock result object itself + - A mock ShellFunction with a .cmd attribute (for patching Function.shell_function) + - A mock result object (for patching _run_shell_locally return value) Usage: def test_something(mock_local_result): shell_func, result = mock_local_result(stdout='{"result": 42}') - # Use shell_func in patches - # Use result for specific assertions + with patch.object(Function, "shell_function", new_callable=PropertyMock, return_value=shell_func): + with patch("groundhog_hpc.function._run_shell_locally", return_value=result): + ... """ def _create( @@ -332,7 +333,8 @@ def _create( result.stderr = stderr result.exception_name = exception_name - shell_func = MagicMock(return_value=result) + shell_func = MagicMock() + shell_func.cmd = "test_cmd {payload}" return shell_func, result return _create diff --git a/tests/test_compute.py b/tests/test_compute.py index 8863dfc..3512e44 100644 --- a/tests/test_compute.py +++ b/tests/test_compute.py @@ -4,10 +4,42 @@ from unittest.mock import MagicMock, patch from uuid import UUID +import pytest + from groundhog_hpc.compute import ( + _poll_batch_results, build_shell_function, + submit_batch, submit_to_executor, ) +from groundhog_hpc.future import GroundhogFuture + +_ENDPOINT = "12345678-1234-1234-1234-123456789abc" +_FUNCTION_ID = "ffffffff-ffff-ffff-ffff-ffffffffffff" + + +def _make_shell_function(name="test_func"): + sf = MagicMock() + sf.__name__ = name + return sf + + +def _make_batch_client(function_id=_FUNCTION_ID, task_ids=None): + """Mock GC client pre-configured for batch submission.""" + task_ids = task_ids or ["tid-0", "tid-1"] + client = MagicMock() + client.register_function.return_value = function_id + client.create_batch.return_value = MagicMock() + client.batch_run.return_value = {"tasks": {function_id: task_ids}} + return client + + +def _success(result): + return {"pending": False, "status": "success", 
"result": result} + + +def _pending(): + return {"pending": True, "status": "unknown"} class TestBuildShellFunction: @@ -127,3 +159,167 @@ def test_walltime_in_config_passed_to_executor( UUID(mock_endpoint_uuid), user_endpoint_config={"account": "test", "walltime": 600}, ) + + +class TestSubmitBatch: + def test_returns_one_future_per_payload(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0", "tid-1", "tid-2"]) + mock_globus_client.return_value = client + + futures = submit_batch( + _ENDPOINT, {}, _make_shell_function(), ["p0", "p1", "p2"] + ) + + assert len(futures) == 3 + assert all(isinstance(f, GroundhogFuture) for f in futures) + + def test_each_future_has_task_id_from_batch_run(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0", "tid-1"]) + mock_globus_client.return_value = client + + futures = submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0", "p1"]) + + assert futures[0].task_id == "tid-0" + assert futures[1].task_id == "tid-1" + + def test_register_function_called_once(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0", "tid-1", "tid-2"]) + mock_globus_client.return_value = client + shell_fn = _make_shell_function() + + submit_batch(_ENDPOINT, {}, shell_fn, ["p0", "p1", "p2"]) + + client.register_function.assert_called_once_with(shell_fn) + + def test_batch_add_called_once_per_payload_with_payload_kwarg( + self, mock_globus_client + ): + client = _make_batch_client(task_ids=["tid-0", "tid-1"]) + mock_globus_client.return_value = client + batch_mock = client.create_batch.return_value + + submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0", "p1"]) + + assert batch_mock.add.call_count == 2 + batch_mock.add.assert_any_call(_FUNCTION_ID, kwargs={"payload": "p0"}) + batch_mock.add.assert_any_call(_FUNCTION_ID, kwargs={"payload": "p1"}) + + def test_endpoint_schema_filtering_applied(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0"]) + 
mock_globus_client.return_value = client + + schema = {"properties": {"account": {"type": "string"}}} + with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=schema): + submit_batch( + _ENDPOINT, + {"account": "proj", "unexpected_key": "val"}, + _make_shell_function(), + ["p0"], + ) + + _, create_batch_kwargs = client.create_batch.call_args + config = create_batch_kwargs["user_endpoint_config"] + assert "account" in config + assert "unexpected_key" not in config + + def test_futures_resolve_via_polling_thread(self, mock_globus_client): + mock_shell_result = MagicMock() + mock_shell_result.returncode = 0 + mock_shell_result.stdout = '"hello"' + mock_shell_result.stderr = "" + + client = _make_batch_client(task_ids=["tid-0"]) + mock_globus_client.return_value = client + + # Resolve the future synchronously by patching _poll_batch_results + def resolve_immediately(task_id_to_future, client, poll_interval=1.0): + task_id_to_future["tid-0"].set_result(mock_shell_result) + + with patch( + "groundhog_hpc.compute._poll_batch_results", side_effect=resolve_immediately + ): + futures = submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0"]) + + assert futures[0].result(timeout=1) == "hello" + + def test_failed_tasks_propagate_exception(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0"]) + mock_globus_client.return_value = client + + def fail_immediately(task_id_to_future, client, poll_interval=1.0): + task_id_to_future["tid-0"].set_exception(RuntimeError("task blew up")) + + with patch( + "groundhog_hpc.compute._poll_batch_results", side_effect=fail_immediately + ): + futures = submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0"]) + + with pytest.raises(RuntimeError, match="task blew up"): + futures[0].result(timeout=1) + + +class TestPollBatchResults: + def test_resolves_successful_task(self): + mock_shell_result = MagicMock() + mock_shell_result.returncode = 0 + mock_shell_result.stdout = '"done"' + + fut = Future() + client 
= MagicMock() + client.get_batch_result.return_value = {"tid-0": _success(mock_shell_result)} + + _poll_batch_results({"tid-0": fut}, client, poll_interval=0) + + assert fut.done() + assert fut.result() is mock_shell_result + + def test_failed_task_sets_exception(self): + mock_exc = MagicMock() + mock_exc.reraise.side_effect = ValueError("remote error") + + fut = Future() + client = MagicMock() + client.get_batch_result.return_value = { + "tid-0": {"pending": False, "status": "failed", "exception": mock_exc} + } + + _poll_batch_results({"tid-0": fut}, client, poll_interval=0) + + assert fut.done() + with pytest.raises(ValueError, match="remote error"): + fut.result() + + def test_pending_task_stays_unresolved_until_next_poll(self): + mock_shell_result = MagicMock() + mock_shell_result.returncode = 0 + mock_shell_result.stdout = '"done"' + + fut = Future() + client = MagicMock() + client.get_batch_result.side_effect = [ + {"tid-0": _pending()}, + {"tid-0": _success(mock_shell_result)}, + ] + + _poll_batch_results({"tid-0": fut}, client, poll_interval=0) + + assert client.get_batch_result.call_count == 2 + assert fut.done() + + def test_polls_only_remaining_pending_tasks(self): + r0, r1 = MagicMock(), MagicMock() + r0.returncode = r1.returncode = 0 + r0.stdout = r1.stdout = '"ok"' + + fut0, fut1 = Future(), Future() + client = MagicMock() + client.get_batch_result.side_effect = [ + {"tid-0": _success(r0), "tid-1": _pending()}, + {"tid-1": _success(r1)}, + ] + + _poll_batch_results({"tid-0": fut0, "tid-1": fut1}, client, poll_interval=0) + + second_call_ids = client.get_batch_result.call_args_list[1][0][0] + assert second_call_ids == ["tid-1"] + assert fut0.done() and fut1.done() diff --git a/tests/test_function.py b/tests/test_function.py index 8b9d036..839f23c 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -352,16 +352,9 @@ class TestLocalMethod: def test_local_executes_function_and_returns_result( self, tmp_path, mock_local_result ): - """Test 
that local() executes the function via ShellFunction and returns result.""" - # Create a test script + """Test that local() executes the function and returns deserialized result.""" script_path = tmp_path / "test_local.py" - script_content = """import groundhog_hpc as hog - -@hog.function() -def add(a, b): - return a + b -""" - script_path.write_text(script_content) + script_path.write_text("# test") def add(a, b): return a + b @@ -369,8 +362,7 @@ def add(a, b): func = Function(add) func._script_path = str(script_path) - # Create mock result - shell_func, result = mock_local_result(stdout='{"result": 5}') + shell_func, run_result = mock_local_result(stdout='{"result": 5}') with patch.object( Function, @@ -379,9 +371,12 @@ def add(a, b): return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", return_value=(None, 5) - ) as mock_deserialize: - result_value = func.local(2, 3) + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.deserialize_stdout", return_value=(None, 5) + ) as mock_deserialize: + result_value = func.local(2, 3) assert result_value == 5 mock_deserialize.assert_called_once_with('{"result": 5}') @@ -394,7 +389,7 @@ def test_local_serializes_arguments(self, tmp_path, mock_local_result): func = Function(dummy_function) func._script_path = str(script_path) - shell_func, result = mock_local_result(stdout='{"result": "success"}') + shell_func, run_result = mock_local_result(stdout='{"result": "success"}') with patch( "groundhog_hpc.function.serialize", return_value="serialized" @@ -406,120 +401,148 @@ def test_local_serializes_arguments(self, tmp_path, mock_local_result): return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", - return_value=(None, "success"), + "groundhog_hpc.function._run_shell_locally", return_value=run_result ): - func.local(1, 2, key="value") + with patch( + "groundhog_hpc.function.deserialize_stdout", + 
return_value=(None, "success"), + ): + func.local(1, 2, key="value") - # Verify serialize was called with args, kwargs, and proxy_threshold_mb=1.0 mock_serialize.assert_called_once() call_args = mock_serialize.call_args[0][0] call_kwargs = mock_serialize.call_args[1] assert call_args == ((1, 2), {"key": "value"}) assert call_kwargs.get("proxy_threshold_mb") == 1.0 - def test_local_runs_in_temporary_directory(self, tmp_path): - """Test that local() sets GC_TASK_SANDBOX_DIR to a temporary directory.""" + def test_gc_task_sandbox_dir_not_set_on_parent_process( + self, tmp_path, mock_local_result + ): + """local() must not mutate os.environ with GC_TASK_SANDBOX_DIR.""" script_path = tmp_path / "test_local.py" script_path.write_text("# test") func = Function(dummy_function) func._script_path = str(script_path) - mock_result = MagicMock() - mock_result.returncode = 0 - mock_result.stdout = "result" - mock_result.stderr = "" - mock_result.exception_name = None - - mock_shell_function = MagicMock(return_value=mock_result) - - # Store original env var if it exists - original_sandbox_dir = os.environ.get("GC_TASK_SANDBOX_DIR") + shell_func, run_result = mock_local_result() + original = os.environ.pop("GC_TASK_SANDBOX_DIR", None) try: - # Clear it for this test - if "GC_TASK_SANDBOX_DIR" in os.environ: - del os.environ["GC_TASK_SANDBOX_DIR"] - with patch.object( Function, "shell_function", new_callable=PropertyMock, - return_value=mock_shell_function, + return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", - return_value=(None, "result"), + "groundhog_hpc.function._run_shell_locally", return_value=run_result ): - func.local() - - # Verify GC_TASK_SANDBOX_DIR was set - assert "GC_TASK_SANDBOX_DIR" in os.environ - sandbox_dir = os.environ["GC_TASK_SANDBOX_DIR"] - assert isinstance(sandbox_dir, str) - assert len(sandbox_dir) > 0 + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "result"), + ): + func.local() + 
assert "GC_TASK_SANDBOX_DIR" not in os.environ finally: - # Restore original state - if original_sandbox_dir is not None: - os.environ["GC_TASK_SANDBOX_DIR"] = original_sandbox_dir - elif "GC_TASK_SANDBOX_DIR" in os.environ: - del os.environ["GC_TASK_SANDBOX_DIR"] + if original is not None: + os.environ["GC_TASK_SANDBOX_DIR"] = original - def test_local_raises_if_script_path_unavailable(self): - """Test that local() raises ValueError if script path cannot be determined.""" + def test_gc_task_sandbox_dir_not_overwritten_if_already_set( + self, tmp_path, mock_local_result + ): + """local() must not overwrite an externally set GC_TASK_SANDBOX_DIR.""" + script_path = tmp_path / "test_local.py" + script_path.write_text("# test") - def local_func(): - return "test" + func = Function(dummy_function) + func._script_path = str(script_path) - func = Function(local_func) - func._script_path = None + shell_func, run_result = mock_local_result() - # Mock inspect.getfile to raise TypeError (e.g., for built-in functions) - with patch( - "groundhog_hpc.function.inspect.getfile", - side_effect=TypeError("not a file"), - ): - with pytest.raises(ValueError, match="Could not determine script path"): - func.local() + os.environ["GC_TASK_SANDBOX_DIR"] = "/my/custom/dir" + try: + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=shell_func, + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "result"), + ): + func.local() + + assert os.environ["GC_TASK_SANDBOX_DIR"] == "/my/custom/dir" + finally: + del os.environ["GC_TASK_SANDBOX_DIR"] + + def test_two_concurrent_local_calls_dont_interfere( + self, tmp_path, mock_local_result + ): + """Concurrent local() calls must not share GC_TASK_SANDBOX_DIR via os.environ.""" + import threading - def test_local_uses_shell_function_property(self, tmp_path, mock_local_result): - 
"""Test that local() uses the cached shell_function property.""" script_path = tmp_path / "test_local.py" script_path.write_text("# test") func = Function(dummy_function) func._script_path = str(script_path) - shell_func, result = mock_local_result(stdout="result") + seen_dirs: list[str] = [] + + def capture_env(cmd_template, payload, tmpdir): + seen_dirs.append(os.environ.get("GC_TASK_SANDBOX_DIR", "NOT_SET")) + mock = MagicMock() + mock.returncode = 0 + mock.stdout = '"ok"' + mock.stderr = "" + mock.exception_name = None + return mock + + shell_func, _ = mock_local_result() + + os.environ.pop("GC_TASK_SANDBOX_DIR", None) with patch.object( Function, "shell_function", new_callable=PropertyMock, return_value=shell_func, - ) as mock_sf_prop: + ): with patch( - "groundhog_hpc.function.deserialize_stdout", - return_value=(None, "result"), + "groundhog_hpc.function._run_shell_locally", side_effect=capture_env ): - func.local() + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "ok"), + ): + threads = [threading.Thread(target=func.local) for _ in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() - # Verify shell_function property was accessed (not script_to_submittable) - mock_sf_prop.assert_called() + # Neither thread should have seen GC_TASK_SANDBOX_DIR in os.environ + assert all(d == "NOT_SET" for d in seen_dirs) + assert "GC_TASK_SANDBOX_DIR" not in os.environ - def test_local_calls_shell_function_with_payload_kwarg( + def test_local_passes_tmpdir_to_run_shell_locally( self, tmp_path, mock_local_result ): - """Test that local() calls shell_function(payload=...) 
not shell_function().""" + """local() passes a real tmpdir path to _run_shell_locally.""" script_path = tmp_path / "test_local.py" script_path.write_text("# test") func = Function(dummy_function) func._script_path = str(script_path) - shell_func, result = mock_local_result(stdout="result") + shell_func, run_result = mock_local_result() with patch.object( Function, @@ -527,41 +550,78 @@ def test_local_calls_shell_function_with_payload_kwarg( new_callable=PropertyMock, return_value=shell_func, ): - with patch("groundhog_hpc.function.serialize", return_value="ABC123"): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ) as mock_run: + with patch("groundhog_hpc.function.serialize", return_value="PAYLOAD"): + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "result"), + ): + func.local() + + # Third argument is tmpdir; second is the serialized payload + _, call_payload, call_tmpdir = mock_run.call_args[0] + assert call_payload == "PAYLOAD" + assert isinstance(call_tmpdir, str) and len(call_tmpdir) > 0 + + def test_local_raises_if_script_path_unavailable(self): + """Test that local() raises ValueError if script path cannot be determined.""" + + def local_func(): + return "test" + + func = Function(local_func) + func._script_path = None + + with patch( + "groundhog_hpc.function.inspect.getfile", + side_effect=TypeError("not a file"), + ): + with pytest.raises(ValueError, match="Could not determine script path"): + func.local() + + def test_local_uses_shell_function_property(self, tmp_path, mock_local_result): + """local() accesses the cached shell_function property for .cmd.""" + script_path = tmp_path / "test_local.py" + script_path.write_text("# test") + + func = Function(dummy_function) + func._script_path = str(script_path) + + shell_func, run_result = mock_local_result(stdout="result") + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=shell_func, + ) 
as mock_sf_prop: + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): with patch( "groundhog_hpc.function.deserialize_stdout", return_value=(None, "result"), ): func.local() - # Verify ShellFunction was called with payload as keyword argument - shell_func.assert_called_once() - assert shell_func.call_args[1]["payload"] == "ABC123" + mock_sf_prop.assert_called() - def test_local_infers_script_path_from_function(self, tmp_path): + def test_local_infers_script_path_from_function(self, tmp_path, mock_local_result): """Test that local() can infer script path from function's source file.""" - # Create a test script script_path = tmp_path / "inferred_script.py" - script_content = """def my_function(): - return 42 -""" - script_path.write_text(script_content) + script_path.write_text("def my_function():\n return 42\n") def my_function(): return 42 func = Function(my_function) - func._script_path = None # Force it to infer - - mock_result = MagicMock() - mock_result.returncode = 0 - mock_result.stdout = "42" - mock_result.stderr = "" - mock_result.exception_name = None + func._script_path = None - mock_shell_function = MagicMock(return_value=mock_result) + shell_func, run_result = mock_local_result(stdout="42") + run_result.returncode = 0 - # Mock inspect.getfile to return our test script with patch( "groundhog_hpc.function.inspect.getfile", return_value=str(script_path) ): @@ -569,12 +629,16 @@ def my_function(): Function, "shell_function", new_callable=PropertyMock, - return_value=mock_shell_function, + return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", return_value=(None, 42) + "groundhog_hpc.function._run_shell_locally", return_value=run_result ): - result = func.local() + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, 42), + ): + result = func.local() assert result == 42 @@ -723,9 +787,8 @@ def test_func(x): test_module = 
sys.modules[func._wrapped_function.__module__] test_module.__groundhog_imported__ = True - shell_func, result = mock_local_result(stdout="84") + shell_func, run_result = mock_local_result(stdout="84") - # Patch shell_function property to verify subprocess is used with patch.object( Function, "shell_function", @@ -733,10 +796,13 @@ def test_func(x): return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", return_value=(None, 84) - ): - result_value = func.local(42) + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ) as mock_run: + with patch( + "groundhog_hpc.function.deserialize_stdout", return_value=(None, 84) + ): + result_value = func.local(42) - # Should always use subprocess (ShellFunction) + # Always uses _run_shell_locally (never calls the function directly) assert result_value == 84 - shell_func.assert_called_once() + mock_run.assert_called_once() diff --git a/tests/test_future.py b/tests/test_future.py index db1d884..693b42a 100644 --- a/tests/test_future.py +++ b/tests/test_future.py @@ -88,27 +88,29 @@ def test_handles_shell_execution_errors(self): assert exc_info.value.returncode == 1 assert "something went wrong" in exc_info.value.stderr - def test_preserves_task_id(self): - """Test that task_id attribute is preserved on the deserializing future.""" + def test_task_id_falls_through_to_original_future(self): + """task_id reads from original future when _task_id is None.""" original = Future() - original.task_id = "test-task-123" + original.task_id = "abc-123" deserializing = GroundhogFuture(original) - # Create a successful result - mock_shell_result = MagicMock() - mock_shell_result.returncode = 0 - mock_shell_result.stdout = '"test"' + assert deserializing.task_id == "abc-123" - original.set_result(mock_shell_result) + def test_explicit_task_id_takes_precedence(self): + """_task_id takes precedence over the original future's task_id attribute.""" + original = Future() + original.task_id = 
"from-future" + deserializing = GroundhogFuture(original) + deserializing._task_id = "explicit" - # Wait for callback - import time + assert deserializing.task_id == "explicit" - time.sleep(0.01) + def test_task_id_returns_none_when_neither_source_has_it(self): + """task_id returns None without raising when the underlying future has no task_id.""" + original = Future() # plain Future, no task_id attribute + deserializing = GroundhogFuture(original) - # Task ID should be preserved - assert hasattr(deserializing, "task_id") - assert deserializing.task_id == "test-task-123" + assert deserializing.task_id is None def test_shell_result_property_returns_raw_result(self): """Test that shell_result property provides access to raw ShellResult.""" diff --git a/tests/test_mark_import_safe.py b/tests/test_mark_import_safe.py index 1259c22..5ca56f1 100644 --- a/tests/test_mark_import_safe.py +++ b/tests/test_mark_import_safe.py @@ -174,13 +174,14 @@ def my_func(): # Verify flag is set assert module.__groundhog_imported__ is True - # Mock shell_function property to avoid actual subprocess execution + # Mock _run_shell_locally to avoid actual subprocess execution mock_shell_func = Mock() - mock_result = Mock() - mock_result.returncode = 0 - mock_result.stdout = 'hello\n__GROUNDHOG_RESULT__\n"hello"' - mock_result.stderr = "" - mock_shell_func.return_value = mock_result + mock_shell_func.cmd = "test {payload}" + mock_run_result = Mock() + mock_run_result.returncode = 0 + mock_run_result.stdout = 'hello\n__GROUNDHOG_RESULT__\n"hello"' + mock_run_result.stderr = "" + mock_run_result.exception_name = None from groundhog_hpc.function import Function @@ -190,9 +191,13 @@ def my_func(): new_callable=PropertyMock, return_value=mock_shell_func, ): - # Now .local() should work (won't raise ModuleImportError) - result = module.my_func.local() - assert result == "hello" + with patch( + "groundhog_hpc.function._run_shell_locally", + return_value=mock_run_result, + ): + # Now .local() should 
work (won't raise ModuleImportError) + result = module.my_func.local() + assert result == "hello" # Cleanup del sys.modules["test_module5"] From 320e648fba0b21e78025606ea308d856172a3ba9 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:35:06 -0600 Subject: [PATCH 05/10] `function.batch_submit` implemented --- src/groundhog_hpc/function.py | 74 ++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 5c56442..83395a1 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -10,6 +10,7 @@ """ import inspect +import itertools import logging import os import subprocess @@ -20,7 +21,7 @@ from typing import TYPE_CHECKING, Any, TypeVar from uuid import UUID -from groundhog_hpc.compute import build_shell_function, submit_to_executor +from groundhog_hpc.compute import build_shell_function, submit_batch, submit_to_executor from groundhog_hpc.configuration.resolver import ConfigResolver from groundhog_hpc.console import display_task_status from groundhog_hpc.errors import ( @@ -332,6 +333,77 @@ def local(self, *args: Any, **kwargs: Any) -> Any: print(user_stdout, file=sys.stdout) return deserialized_result + def batch_submit( + self, + args: list[tuple] = [], + kwargs: list[dict] = [], + endpoint: str | None = None, + user_endpoint_config: dict[str, Any] | None = None, + ) -> list[GroundhogFuture]: + """Submit the function for asynchronous remote execution as a batch. + + Submits all tasks as a single Globus Compute batch request, avoiding + per-task API calls that can hit rate limits. 
+ + Args: + args: List of positional-argument tuples, one per task + kwargs: List of keyword-argument dicts, one per task + endpoint: Globus Compute endpoint UUID or named endpoint + user_endpoint_config: Endpoint configuration dict + + Returns: + A list of GroundhogFutures in the same order as the input tasks + + Raises: + ModuleImportError: If called during module import + ValueError: If both args and kwargs are empty + """ + module = sys.modules.get(self._wrapped_function.__module__) + if not getattr(module, "__groundhog_imported__", False): + raise ModuleImportError( + self._wrapped_function.__name__, + "batch_submit", + self._wrapped_function.__module__, + ) + + if max(len(args), len(kwargs)) == 0: + raise ValueError( + "batch_submit requires at least one task: args and kwargs are both empty" + ) + + endpoint = endpoint or self.endpoint + decorator_config = self.default_user_endpoint_config.copy() + call_time_config = user_endpoint_config.copy() if user_endpoint_config else {} + config = self.config_resolver.resolve( + endpoint_name=endpoint or "", + decorator_config=decorator_config, + call_time_config=call_time_config, + ) + if "endpoint" in config: + endpoint = config.pop("endpoint") + if not endpoint: + available_endpoints = self._get_available_endpoints_from_pep723() + if available_endpoints: + endpoints_str = ", ".join(f"'{e}'" for e in available_endpoints) + raise ValueError( + f"No endpoint specified. Available endpoints found in config: {endpoints_str}." 
+ ) + raise ValueError("No endpoint specified") + + payloads = [] + for a, kw in itertools.zip_longest(args, kwargs, fillvalue=None): + a = a if a is not None else () + kw = kw if kw is not None else {} + payloads.append( + serialize((a, kw), use_proxy=False, proxy_threshold_mb=None) + ) + + futures = submit_batch(UUID(endpoint), config, self.shell_function, payloads) + for future in futures: + future.function_name = self.name + future.user_endpoint_config = config + return futures + @property def shell_command(self) -> str: """Parameterized shell command string with a {payload} placeholder. From ad71650b6cfcfc3e08b4a6a846df57d0cff2de6e Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:41:42 -0600 Subject: [PATCH 06/10] fix mutable defaults --- src/groundhog_hpc/function.py | 8 ++- tests/test_function.py | 127 ++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 3 deletions(-) diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 83395a1..4f40bb6 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -335,8 +335,8 @@ def local(self, *args: Any, **kwargs: Any) -> Any: def batch_submit( self, - args: list[tuple] = [], - kwargs: list[dict] = [], + args: list[tuple] | None = None, + kwargs: list[dict] | None = None, endpoint: str | None = None, user_endpoint_config: dict[str, Any] | None = None, ) -> list[GroundhogFuture]: @@ -349,7 +349,7 @@ def batch_submit( args: List of positional-argument tuples, one per task kwargs: List of keyword-argument dicts, one per task endpoint: Globus Compute endpoint UUID or named endpoint - user_endpoint_config: Endpoint configuration dict + user_endpoint_config: Endpoint configuration dict (merged with decorator default) Returns: A list of GroundhogFutures in the same order as the input tasks @@ -358,6 +358,8 @@ def batch_submit( ModuleImportError: If called during module import 
ValueError: If both args and kwargs are empty """ + args = args or [] + kwargs = kwargs or [] module = sys.modules.get(self._wrapped_function.__module__) if not getattr(module, "__groundhog_imported__", False): raise ModuleImportError( diff --git a/tests/test_function.py b/tests/test_function.py index 839f23c..4a9a3e7 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -806,3 +806,130 @@ def test_func(x): # Always uses _run_shell_locally (never calls the function directly) assert result_value == 84 mock_run.assert_called_once() + + +class TestBatchSubmit: + """Tests for Function.batch_submit().""" + + def _make_func(self, tmp_path, mock_endpoint_uuid): + script_path = tmp_path / "test_script.py" + script_path.write_text("# test") + func = Function(dummy_function, endpoint=mock_endpoint_uuid) + func._script_path = str(script_path) + return func + + def _mock_submit_batch(self, n=3): + """Return a mock submit_batch that produces n GroundhogFutures.""" + from concurrent.futures import Future as CF + + futures = [] + for i in range(n): + cf = CF() + cf.set_result(MagicMock(returncode=0, stdout=f'"{i}"', stderr="")) + gf = MagicMock() + gf._task_id = f"tid-{i}" + futures.append(gf) + return MagicMock(return_value=futures), futures + + def test_raises_without_import_flag(self, tmp_path, mock_endpoint_uuid): + import sys + + func = self._make_func(tmp_path, mock_endpoint_uuid) + test_module = sys.modules.get("tests.test_fixtures") + had_flag = hasattr(test_module, "__groundhog_imported__") + if had_flag: + del test_module.__groundhog_imported__ + try: + with pytest.raises(ModuleImportError): + func.batch_submit(args=[(1,)]) + finally: + if had_flag: + test_module.__groundhog_imported__ = True + + def test_raises_when_args_and_kwargs_both_empty(self, tmp_path, mock_endpoint_uuid): + func = self._make_func(tmp_path, mock_endpoint_uuid) + with pytest.raises(ValueError, match="both empty"): + func.batch_submit() + + def test_returns_one_future_per_task( + 
self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=3) + with patch("groundhog_hpc.function.submit_batch", mock_batch): + result = func.batch_submit(args=[(1,), (2,), (3,)]) + assert len(result) == 3 + + def test_args_and_kwargs_zipped_with_fill( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=2) + captured = [] + + def fake_serialize(data, **kw): + captured.append(data) + return f"payload_{len(captured)}" + + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", side_effect=fake_serialize): + func.batch_submit(args=[(1,), (2,)], kwargs=[{"k": "v"}]) + + assert captured[0] == ((1,), {"k": "v"}) + assert captured[1] == ((2,), {}) + + def test_kwargs_only_batch_uses_empty_args_tuple( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=2) + captured = [] + + def fake_serialize(data, **kw): + captured.append(data) + return "p" + + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", side_effect=fake_serialize): + func.batch_submit(kwargs=[{"x": 1}, {"x": 2}]) + + assert captured[0] == ((), {"x": 1}) + assert captured[1] == ((), {"x": 2}) + + def test_uses_resolved_endpoint( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=1) + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", return_value="p"): + func.batch_submit(args=[(1,)]) + endpoint_arg = mock_batch.call_args[0][0] + from uuid import UUID + + assert endpoint_arg == 
UUID(mock_endpoint_uuid) + + def test_callsite_endpoint_overrides_decorator( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + other_uuid = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=1) + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", return_value="p"): + func.batch_submit(args=[(1,)], endpoint=other_uuid) + from uuid import UUID + + assert mock_batch.call_args[0][0] == UUID(other_uuid) + + def test_function_name_and_config_set_on_each_future( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, mock_futures = self._mock_submit_batch(n=3) + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", return_value="p"): + result = func.batch_submit(args=[(1,), (2,), (3,)]) + for f in result: + assert f.function_name == func.name + assert f.user_endpoint_config is not None From ee94fe929a8445d5c0234502de40b0956185a052 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Fri, 6 Mar 2026 12:49:12 -0600 Subject: [PATCH 07/10] `Function.batch_local` and executor kwargs --- src/groundhog_hpc/compute.py | 6 +- src/groundhog_hpc/function.py | 61 ++++++++- tests/test_function.py | 238 ++++++++++++++++++++++++++++++++++ 3 files changed, 303 insertions(+), 2 deletions(-) diff --git a/src/groundhog_hpc/compute.py b/src/groundhog_hpc/compute.py index e5ca8da..03d0d9d 100644 --- a/src/groundhog_hpc/compute.py +++ b/src/groundhog_hpc/compute.py @@ -74,6 +74,7 @@ def submit_to_executor( user_endpoint_config: dict[str, Any], shell_function: ShellFunction, payload: str, + executor_kwargs: dict[str, Any] | None = None, ) -> GroundhogFuture: """Submit a ShellFunction to a Globus Compute endpoint for execution. 
@@ -82,6 +83,7 @@ def submit_to_executor( user_endpoint_config: Configuration dict for the endpoint (e.g., worker_init, walltime) shell_function: The parameterized ShellFunction to execute payload: Serialized arguments string, substituted into the {payload} placeholder + executor_kwargs: Extra keyword arguments forwarded directly to gc.Executor constructor Returns: A GroundhogFuture that will contain the deserialized result @@ -100,7 +102,9 @@ def submit_to_executor( config = {k: v for k, v in config.items() if k not in unexpected_keys} logger.debug(f"Creating Globus Compute executor for endpoint {endpoint}") - with gc.Executor(endpoint, user_endpoint_config=config) as executor: + with gc.Executor( + endpoint, user_endpoint_config=config, **(executor_kwargs or {}) + ) as executor: func_name = getattr( shell_function, "__name__", getattr(shell_function, "name", "unknown") ) diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 4f40bb6..6c767af 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -16,6 +16,7 @@ import subprocess import sys import tempfile +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from types import FunctionType from typing import TYPE_CHECKING, Any, TypeVar @@ -46,7 +47,7 @@ ShellResult = TypeVar("ShellResult") -def _run_shell_locally(cmd_template: str, payload: str, tmpdir: str) -> Any: +def _run_shell_locally(cmd_template: str, payload: str, tmpdir: str) -> ShellResult: """Execute a parameterized shell command locally. Injects GC_TASK_SANDBOX_DIR into the subprocess environment without @@ -148,6 +149,7 @@ def submit( *args: Any, endpoint: str | None = None, user_endpoint_config: dict[str, Any] | None = None, + executor_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> GroundhogFuture: """Submit the function for asynchronous remote execution. 
@@ -157,6 +159,7 @@ def submit( endpoint: Globus Compute endpoint UUID (or named endpoint from `[tool.hog.]` PEP 723 metadata). Replaces decorator default. user_endpoint_config: Endpoint configuration dict (merged with decorator default) + executor_kwargs: Keyword arguments forwarded to Globus Compute Executor **kwargs: Keyword arguments to pass to the function Returns: @@ -221,6 +224,7 @@ def submit( user_endpoint_config=config, shell_function=self.shell_function, payload=payload, + executor_kwargs=executor_kwargs, ) future.endpoint = endpoint future.user_endpoint_config = config @@ -232,6 +236,7 @@ def remote( *args: Any, endpoint: str | None = None, user_endpoint_config: dict[str, Any] | None = None, + executor_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Execute the function remotely and block until completion. @@ -244,6 +249,7 @@ def remote( endpoint: Globus Compute endpoint UUID (or named endpoint from `[tool.hog.]` PEP 723 metadata). Replaces decorator default. user_endpoint_config: Endpoint configuration dict (merged with decorator default) + executor_kwargs: Keyword arguments forwarded to Globus Compute Executor **kwargs: Keyword arguments to pass to the function Returns: @@ -260,6 +266,7 @@ def remote( *args, endpoint=endpoint, user_endpoint_config=user_endpoint_config, + executor_kwargs=executor_kwargs, **kwargs, ) display_task_status(future) @@ -406,6 +413,58 @@ def batch_submit( future.user_endpoint_config = config return futures + def batch_local( + self, + args: list[tuple] = [], + kwargs: list[dict] = [], + executor_kwargs: dict[str, Any] | None = None, + ) -> list[GroundhogFuture]: + """Execute the function locally in parallel subprocesses for each task. + + Submits all tasks to a ThreadPoolExecutor immediately and returns futures + without waiting for completion. Each task runs in its own subprocess with + an isolated temporary directory. 
+ + Args: + args: List of positional-argument tuples, one per task + kwargs: List of keyword-argument dicts, one per task + executor_kwargs: Keyword arguments forwarded to ThreadPoolExecutor + + Returns: + A list of GroundhogFutures in the same order as the input tasks + + Raises: + ModuleImportError: If called during module import + ValueError: If both args and kwargs are empty + """ + module = sys.modules.get(self._wrapped_function.__module__) + if not getattr(module, "__groundhog_imported__", False): + raise ModuleImportError( + self._wrapped_function.__name__, + "batch_local", + self._wrapped_function.__module__, + ) + + if max(len(args), len(kwargs)) == 0: + raise ValueError( + "batch_local requires at least one task: args and kwargs are both empty" + ) + + payloads = [] + for a, kw in itertools.zip_longest(args, kwargs, fillvalue=None): + a = a if a is not None else () + kw = kw if kw is not None else {} + payloads.append(serialize((a, kw), proxy_threshold_mb=1.0)) + + cmd_template = self.shell_function.cmd + + def _worker(payload: str) -> ShellResult: + with tempfile.TemporaryDirectory() as tmpdir: + return _run_shell_locally(cmd_template, payload, tmpdir) + + executor = ThreadPoolExecutor(**(executor_kwargs or {})) + return [GroundhogFuture(executor.submit(_worker, p)) for p in payloads] + @property def shell_command(self) -> str: """Parameterized shell command string with a {payload} placeholder. 
diff --git a/tests/test_function.py b/tests/test_function.py index 4a9a3e7..2ca1084 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,12 +1,14 @@ """Tests for the Function class.""" import os +from concurrent.futures import ThreadPoolExecutor from unittest.mock import MagicMock, PropertyMock, patch import pytest from groundhog_hpc.errors import ModuleImportError from groundhog_hpc.function import Function +from groundhog_hpc.future import GroundhogFuture from tests.test_fixtures import simple_function # Alias for backward compatibility with existing tests @@ -345,6 +347,46 @@ def test_default_worker_init_preserved_when_no_callsite_override( assert "worker_init" in config assert default_worker_init in config["worker_init"] + def test_executor_kwargs_forwarded_to_submit_to_executor( + self, function_with_script, mock_submission_stack + ): + """Test that executor_kwargs are forwarded to submit_to_executor.""" + func = function_with_script() + + func.submit(executor_kwargs={"amqp_port": 5671}) + + mock_submit = mock_submission_stack["submit_to_executor"] + assert mock_submit.call_args[1]["executor_kwargs"] == {"amqp_port": 5671} + + def test_executor_kwargs_defaults_to_none( + self, function_with_script, mock_submission_stack + ): + """Test that executor_kwargs defaults to None when not provided.""" + func = function_with_script() + + func.submit() + + mock_submit = mock_submission_stack["submit_to_executor"] + assert mock_submit.call_args[1]["executor_kwargs"] is None + + def test_executor_kwargs_does_not_bleed_into_user_endpoint_config( + self, function_with_script, mock_submission_stack + ): + """Test that executor_kwargs keys are not added to user_endpoint_config.""" + func = function_with_script() + + mock_schema = {"properties": {"account": {"type": "string"}}} + mock_submission_stack["get_endpoint_schema"].return_value = mock_schema + + func.submit( + executor_kwargs={"amqp_port": 5671}, + user_endpoint_config={"account": "x"}, + ) + + 
mock_submit = mock_submission_stack["submit_to_executor"] + config = mock_submit.call_args[1]["user_endpoint_config"] + assert "amqp_port" not in config + class TestLocalMethod: """Test the local() method for running functions in local subprocess.""" @@ -933,3 +975,199 @@ def test_function_name_and_config_set_on_each_future( for f in result: assert f.function_name == func.name assert f.user_endpoint_config is not None + + +class TestBatchLocal: + """Tests for Function.batch_local().""" + + def _make_func(self, tmp_path): + script_path = tmp_path / "test_script.py" + script_path.write_text("# test") + func = Function(dummy_function) + func._script_path = str(script_path) + return func + + def _mock_shell_func(self): + sf = MagicMock() + sf.cmd = "test_cmd {payload}" + return sf + + def _make_run_result(self, stdout='"ok"'): + r = MagicMock() + r.returncode = 0 + r.stdout = stdout + r.stderr = "" + r.exception_name = None + return r + + def test_raises_without_import_flag(self, tmp_path): + import sys + + func = self._make_func(tmp_path) + test_module = sys.modules.get("tests.test_fixtures") + had_flag = hasattr(test_module, "__groundhog_imported__") + if had_flag: + del test_module.__groundhog_imported__ + try: + with pytest.raises(ModuleImportError): + func.batch_local(args=[(1,)]) + finally: + if had_flag: + test_module.__groundhog_imported__ = True + + def test_raises_when_args_and_kwargs_both_empty(self, tmp_path): + func = self._make_func(tmp_path) + with pytest.raises(ValueError, match="both empty"): + func.batch_local() + + def test_returns_one_future_per_task(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = 
func.batch_local(args=[(1,), (2,)]) + assert len(futures) == 2 + assert all(isinstance(f, GroundhogFuture) for f in futures) + + def test_returns_immediately_without_blocking(self, tmp_path): + import threading + + func = self._make_func(tmp_path) + started = threading.Event() + finished = threading.Event() + + def slow_run(cmd_template, payload, tmpdir): + started.set() + finished.wait(timeout=2) + return self._make_run_result() + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", side_effect=slow_run + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = func.batch_local(args=[(1,)]) + + # batch_local returned before the worker finished + assert len(futures) == 1 + assert not futures[0].done() + finished.set() + + def test_args_and_kwargs_zipped_with_fill(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + captured = [] + + def fake_serialize(data, **kw): + captured.append(data) + return "p" + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.serialize", side_effect=fake_serialize + ): + func.batch_local(args=[(1,), (2,)], kwargs=[{"k": "v"}]) + + assert captured[0] == ((1,), {"k": "v"}) + assert captured[1] == ((2,), {}) + + def test_executor_kwargs_passed_to_thread_pool(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + with patch( + 
"groundhog_hpc.function.ThreadPoolExecutor", + wraps=ThreadPoolExecutor, + ) as mock_tpe: + func.batch_local( + args=[(1,)], executor_kwargs={"max_workers": 2} + ) + mock_tpe.assert_called_once_with(max_workers=2) + + def test_gc_task_sandbox_dir_not_set_on_parent_process(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + os.environ.pop("GC_TASK_SANDBOX_DIR", None) + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = func.batch_local(args=[(1,)]) + # Wait for workers to finish + for f in futures: + try: + f.result(timeout=2) + except Exception: + pass + assert "GC_TASK_SANDBOX_DIR" not in os.environ + + def test_task_id_is_none_for_all_local_futures(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = func.batch_local(args=[(1,), (2,)]) + for f in futures: + assert f.task_id is None + + def test_serialize_called_once_per_task(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.serialize", return_value="p" + ) as mock_ser: + func.batch_local(args=[(1,), (2,), (3,)]) + assert mock_ser.call_count == 3 From 855552143a5bdede15ebc22e2222b79643e67b3f Mon Sep 17 00:00:00 
2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Fri, 6 Mar 2026 12:50:56 -0600 Subject: [PATCH 08/10] docs/examples --- docs/api/function.md | 2 + docs/examples/parallel-execution.md | 100 +++++++++++++++++++++++++--- examples/parallel_execution.py | 34 ++++++++-- 3 files changed, 120 insertions(+), 16 deletions(-) diff --git a/docs/api/function.md b/docs/api/function.md index d3179c1..11913d4 100644 --- a/docs/api/function.md +++ b/docs/api/function.md @@ -14,6 +14,8 @@ The `Function` class wraps user functions decorated with `@hog.function()` to en - remote - submit - local + - batch_submit + - batch_local ## Method Class diff --git a/docs/examples/parallel-execution.md b/docs/examples/parallel-execution.md index 14738eb..880b4ca 100644 --- a/docs/examples/parallel-execution.md +++ b/docs/examples/parallel-execution.md @@ -1,6 +1,6 @@ # Parallel Execution -This example demonstrates the difference between sequential execution with `.remote()` and parallel execution with `.submit()`. +This example demonstrates sequential execution with `.remote()`, parallel execution with `.submit()`, and batch execution with `.batch_submit()` and `.batch_local()`. 
## When to Use Each Method @@ -20,7 +20,18 @@ This example demonstrates the difference between sequential execution with `.rem - You don't care for the console display - You need access to the `GroundhogFuture` object -## Full Example +**Use `.batch_submit()` when:** + +- You're submitting many tasks to the same remote endpoint +- You want to avoid Globus Compute rate limits (batching is one API call instead of N) +- All tasks use the same function with different arguments + +**Use `.batch_local()` when:** + +- You want to run many tasks in parallel locally +- You want immediate `GroundhogFuture`s instead of `.local()`'s blocking behavior + +## Example: Remote vs Submit ```python title="parallel_execution.py" # /// script @@ -28,7 +39,7 @@ This example demonstrates the difference between sequential execution with `.rem # dependencies = [] # # [tool.uv] -# exclude-newer = "2025-12-02T19:48:40Z" +# exclude-newer = "2026-03-06T00:00:00Z" # # [tool.hog.anvil] # endpoint = "5aafb4c1-27b2-40d8-a038-a0277611868f" @@ -66,6 +77,28 @@ def main(): results = [f.result() for f in futures] # (3)! print(f" Results: {results}") print(f" Time: {time.time() - start:.1f}s (approximately 2s)") + + +@hog.harness() +def batch(): + """Run with: hog run parallel_execution.py batch""" + # .batch_submit() registers the function once and sends all tasks in a + # single API request, avoiding the per-task rate limits of a .submit() loop. + print("Batch remote submission:") + futures = slow_square.batch_submit( + args=[(0,), (1,), (2,), (3,), (4,)], + ) + results = [f.result() for f in futures] + print(f" Results: {results}") # [0, 1, 4, 9, 16] + + # .batch_local() runs each task in its own subprocess in parallel. + print("Batch local execution:") + futures = slow_square.batch_local( + args=[(0,), (1,), (2,), (3,), (4,)], + executor_kwargs={"max_workers": 4}, + ) + results = [f.result() for f in futures] + print(f" Results: {results}") # [0, 1, 4, 9, 16] ``` 1. 
`.remote()` blocks until the function completes. Each call waits for the previous one to finish. Total time: 3 tasks x 2 seconds = ~6 seconds. @@ -74,36 +107,81 @@ def main(): 3. Calling `.result()` on each future blocks until that task completes. Since all tasks run in parallel, total time is ~2 seconds. + +## Example: Batching Locally / Remotely + +A loop of `.submit()` calls makes one API request per task and can hit Globus Compute rate limits at large N. `.batch_submit()` registers the function once and sends all tasks in a single request. + +```python +# Instead of this (N separate API calls): +futures = [slow_square.submit(i) for i in range(5)] + +# Use batch_submit (one API call): +futures = slow_square.batch_submit( + args=[(0,), (1,), (2,), (3,), (4,)], # (1)! +) +results = [f.result() for f in futures] +# [0, 1, 4, 9, 16] +``` + +1. Each tuple is unpacked as positional arguments for one task. Pass `kwargs=[...]` alongside `args` to mix positional and keyword arguments — when the two lists have different lengths, the shorter one fills with `()` or `{}`. + +`.batch_local()` runs each task in its own subprocess with an isolated temporary directory: + +```python +futures = slow_square.batch_local( + args=[(0,), (1,), (2,), (3,), (4,)], + executor_kwargs={"max_workers": 4}, # (1)! +) +results = [f.result() for f in futures] +# [0, 1, 4, 9, 16] +``` + +1. `executor_kwargs` is forwarded directly to `ThreadPoolExecutor`. Omit it to use the default worker count. + ## Working with GroundhogFutures +`.submit()` and both batch methods return `GroundhogFuture` objects. They behave like standard `concurrent.futures.Future` objects, with additional Groundhog-specific properties. 
+ ```python future = slow_square.submit(5) +# Get the deserialized return value (blocks until ready) +result = future.result() +result = future.result(timeout=10) # Raises TimeoutError if not ready + # Check if done (non-blocking) if future.done(): print("Task completed!") -# Get the result (blocks until ready) -result = future.result() - -# Get result with timeout -result = future.result(timeout=10) # Raises TimeoutError if not ready - # Cancel a pending task future.cancel() -# Inspect the underlying ShellResult +# Inspect raw shell execution metadata print(future.shell_result.returncode) print(future.shell_result.stderr) + +# Capture stdout from print() calls inside the remote function +if future.user_stdout: + print(future.user_stdout) + +# Inspect the resolved configuration that was actually passed to the endpoint +print(future.user_endpoint_config) # {"account": "...", "partition": "..."} +print(future.task_id) # Globus Compute task ID +print(future.function_name) # "slow_square" ``` ## Running the Example ```bash +# Remote vs submit timing comparison hog run examples/parallel_execution.py + +# Batch submission and local parallel execution +hog run examples/parallel_execution.py batch ``` -Expected output: +Expected output from `main`: ``` Sequential execution with .remote(): diff --git a/examples/parallel_execution.py b/examples/parallel_execution.py index 47d08dc..1562e83 100644 --- a/examples/parallel_execution.py +++ b/examples/parallel_execution.py @@ -3,18 +3,19 @@ # dependencies = [] # # [tool.uv] -# exclude-newer = "2025-12-02T19:48:40Z" +# exclude-newer = "2026-03-06T00:00:00Z" # # [tool.hog.anvil] # endpoint = "5aafb4c1-27b2-40d8-a038-a0277611868f" # account = "cis250461" -# requirements = "" # /// """ -Example showing parallel execution with .submit() vs sequential with .remote(). +Example showing parallel and batch execution patterns. Use .remote() when you want to wait for each result before continuing. 
Use .submit() when you want to run multiple tasks in parallel. +Use .batch_submit() to submit many tasks without hitting rate limits. +Use .batch_local() for parallel local execution on the login node. """ import groundhog_hpc as hog @@ -39,12 +40,35 @@ def main(): start = time.time() results = [slow_square.remote(i) for i in range(3)] print(f" Results: {results}") - print(f" Time: {time.time() - start:.1f}s (approximately 6s)\n") + print(f" Time: {time.time() - start:.1f}s \n") # Parallel: .submit() returns immediately, tasks run concurrently print("Parallel execution with .submit():") start = time.time() futures = [slow_square.submit(i) for i in range(3)] - results = [f.result() for f in futures] # Wait for all results + results = [f.result() for f in futures] print(f" Results: {results}") print(f" Time: {time.time() - start:.1f}s (approximately 2s)") + + +@hog.harness() +def batch(n: int = 5): + """Run with: hog run parallel_execution.py batch""" + # .batch_submit() registers the function once and sends all tasks in a + # single API request, avoiding the per-task rate limits of a .submit() loop. 
+ print("Batch remote submission:") + futures = slow_square.batch_submit( + endpoint="anvil", + args=[(i,) for i in range(n)], + ) + results = [f.result() for f in futures] + print(f" Results: {results}") # [0, 1, 4, 9, 16] + + # .batch_local() runs each task in its own subprocess in parallel + print("Batch local execution:") + futures = slow_square.batch_local( + args=[(i,) for i in range(n)], + executor_kwargs={"max_workers": 4}, + ) + results = [f.result() for f in futures] + print(f" Results: {results}") # [0, 1, 4, 9, 16] From a23f99d0b3fa542d0869c37406df375fa6910489 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:18:01 -0500 Subject: [PATCH 09/10] example update to use local / batch_local --- docs/examples/parallel-execution.md | 19 +++++----- examples/parallel_execution.py | 56 ++++++++++++++++------------- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/docs/examples/parallel-execution.md b/docs/examples/parallel-execution.md index 880b4ca..87c04b3 100644 --- a/docs/examples/parallel-execution.md +++ b/docs/examples/parallel-execution.md @@ -16,7 +16,6 @@ This example demonstrates sequential execution with `.remote()`, parallel execut **Use `.submit()` when:** -- You have multiple independent tasks that can run concurrently - You don't care for the console display - You need access to the `GroundhogFuture` object @@ -174,23 +173,23 @@ print(future.function_name) # "slow_square" ## Running the Example ```bash -# Remote vs submit timing comparison +# sequential vs batch timing comparison (local methods) hog run examples/parallel_execution.py -# Batch submission and local parallel execution -hog run examples/parallel_execution.py batch +# .remote vs .submit vs .batch_submit +hog run examples/parallel_execution.py remote ``` Expected output from `main`: ``` -Sequential execution with .remote(): - Results: [0, 1, 4] - Time: 6.2s (approximately 6s) +Sequential execution 
with .local(): + Results: [0, 1, 4, 9, 16] + Time: 11.1s -Parallel execution with .submit(): - Results: [0, 1, 4] - Time: 2.1s (approximately 2s) +Parallel execution with .batch_local(): + Results: [0, 1, 4, 9, 16] + Time: 2.2s ``` ## Next Steps diff --git a/examples/parallel_execution.py b/examples/parallel_execution.py index 1562e83..ba4ce64 100644 --- a/examples/parallel_execution.py +++ b/examples/parallel_execution.py @@ -31,44 +31,50 @@ def slow_square(n: int) -> int: @hog.harness() -def main(): - """Run with: hog run parallel_execution.py""" +def main(n: int = 5): + """Run like: hog run parallel_execution.py -- --n=5""" import time # Sequential: each .remote() blocks until complete - print("Sequential execution with .remote():") + print("Sequential execution with .local():") start = time.time() - results = [slow_square.remote(i) for i in range(3)] + results = [slow_square.local(i) for i in range(n)] print(f" Results: {results}") print(f" Time: {time.time() - start:.1f}s \n") - # Parallel: .submit() returns immediately, tasks run concurrently - print("Parallel execution with .submit():") + print("Parallel execution with .batch_local():") start = time.time() - futures = [slow_square.submit(i) for i in range(3)] + futures = slow_square.batch_local(args=[(i,) for i in range(n)]) results = [f.result() for f in futures] print(f" Results: {results}") - print(f" Time: {time.time() - start:.1f}s (approximately 2s)") + print(f" Time: {time.time() - start:.1f}s ") @hog.harness() -def batch(n: int = 5): - """Run with: hog run parallel_execution.py batch""" - # .batch_submit() registers the function once and sends all tasks in a - # single API request, avoiding the per-task rate limits of a .submit() loop. 
- print("Batch remote submission:") - futures = slow_square.batch_submit( - endpoint="anvil", - args=[(i,) for i in range(n)], - ) +def remote(n: int = 5): + """Run like: hog run parallel_execution.py remote -- --n=5""" + import time + + args_list = [(i,) for i in range(n)] + # Sequential: each .remote() blocks until complete + print("Sequential execution with .remote():") + start = time.time() + results = [slow_square.remote(*args) for args in args_list] + print(f" Results: {results}") + print(f" Time: {time.time() - start:.1f}s \n") + + # Parallel: .submit() returns immediately, tasks run ~concurrently (N globus api calls) + print("Parallel execution with .submit():") + start = time.time() + futures = [slow_square.submit(*args) for args in args_list] results = [f.result() for f in futures] - print(f" Results: {results}") # [0, 1, 4, 9, 16] + print(f" Results: {results}") + print(f" Time: {time.time() - start:.1f}s ") - # .batch_local() runs each task in its own subprocess in parallel - print("Batch local execution:") - futures = slow_square.batch_local( - args=[(i,) for i in range(n)], - executor_kwargs={"max_workers": 4}, - ) + # Parallel: .batch_submit() returns immediately, tasks run concurrently (1 globus api call) + print("Parallel execution with .batch_submit():") + start = time.time() + futures = slow_square.batch_submit(args=args_list) results = [f.result() for f in futures] - print(f" Results: {results}") # [0, 1, 4, 9, 16] + print(f" Results: {results}") + print(f" Time: {time.time() - start:.1f}s ") From faef3f50417f240ec8f8b59b4e9512f5a1d847b8 Mon Sep 17 00:00:00 2001 From: Owen Price Skelly <21372141+OwenPriceSkelly@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:38:50 -0500 Subject: [PATCH 10/10] misc docs updates --- docs/concepts/functions-and-harnesses.md | 4 ++-- docs/examples/index.md | 2 +- docs/examples/local.md | 2 +- docs/getting-started/quickstart.md | 9 +++++---- docs/index.md | 2 +- src/groundhog_hpc/function.py | 5 +++-- 6 files 
changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/concepts/functions-and-harnesses.md b/docs/concepts/functions-and-harnesses.md index 002c01f..b0bebd1 100644 --- a/docs/concepts/functions-and-harnesses.md +++ b/docs/concepts/functions-and-harnesses.md @@ -17,7 +17,7 @@ def train_model(dataset: str, epochs: int) -> dict: return {"accuracy": 0.95} ``` -Functions provide four execution modes: +Functions provide several execution modes: | Method | Where it runs | Behavior | |--------|---------------|----------| @@ -111,6 +111,6 @@ hog run script.py -- --epochs=20 # Runs main with epochs=20 ## Next steps -- **[Parallel Execution](../examples/parallel-execution.md)** - Use `.submit()` to run functions concurrently +- **[Parallel Execution](../examples/parallel-execution.md)** - Using `.batch_*` methods to run functions concurrently - **[Parameterized Harness Example](../examples/parameterized-harness.md)** - Complete example with CLI arguments - **[Remote Execution Flow](remote-execution.md)** - Understand what happens when you call `.remote()` diff --git a/docs/examples/index.md b/docs/examples/index.md index 572b8e8..b39e182 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -15,7 +15,7 @@ These examples cover the basics of using Groundhog: Examples showing how to handle typical workflows: -- **[Parallel Execution](parallel-execution.md)** - Using `.submit()` for concurrent remote execution +- **[Parallel Execution](parallel-execution.md)** - Using `.batch_submit()` or `.batch_local()` for concurrent execution - **[Parameterized Harnesses](parameterized-harness.md)** - Harnesses that accept CLI arguments for runtime configuration - **[Endpoint Configuration](configuration.md)** - How the configuration system merges settings from multiple sources (PEP 723, decorators, call-time overrides) - **[PyTorch from Custom Sources](pytorch_custom_index.md)** - Configuring uv to install packages from cluster-specific indexes, local paths, or internal 
mirrors diff --git a/docs/examples/local.md b/docs/examples/local.md index a06e919..411dc78 100644 --- a/docs/examples/local.md +++ b/docs/examples/local.md @@ -128,5 +128,5 @@ Using .local() - runs in subprocess with numpy installed: ## Next Steps -- **[Parallel Execution](parallel-execution.md)** - Run multiple functions concurrently with `.submit()` +- **[Parallel Execution](parallel-execution.md)** - Run multiple functions concurrently with `.batch_submit()` or `.batch_local()` - **[Configuration](configuration.md)** - Configure multiple endpoints diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 1bfa1d2..0d2ccd5 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -54,14 +54,15 @@ The comment block at the top uses [PEP 723](https://peps.python.org/pep-0723/) i - **`requires-python`**: Python version requirement for remote execution - **`dependencies`**: Python packages needed by your function (managed by uv) -- **`[tool.uv]`**: Optional configuration read by `uv run` when creating the ephemeral remote environment (see also: [full uv settings reference](https://docs.astral.sh/uv/reference/settings/)) -- **`[tool.hog.my-endpoint]`**: Endpoint configuration with HPC-specific settings like account, partition, walltime, etc. +- **`[tool.uv]`**: Optional configuration read by `uv venv` and `uv pip install` when creating the remote environment (see also: [full uv settings reference](https://docs.astral.sh/uv/reference/settings/)) +- **`[tool.hog.my-endpoint]`**: Endpoint configuration with HPC-specific settings like account, partition, walltime, etc. Recognized configuration options depend on the particular endpoint. + ### Functions and harnesses - **`@hog.function()`**: Decorates a Python function to make it executable remotely - **`@hog.harness()`**: Decorates an orchestrator function that coordinates remote calls. 
Harnesses can accept parameters passed as CLI arguments (see [Functions and Harnesses](../concepts/functions-and-harnesses.md)) -- **`.remote()`**: Executes the function remotely and blocks until complete (alternatively, use **`.submit()`** for async execution) +- **`.remote()`**: Executes the function remotely and blocks until complete (alternatively, use **`.submit()`** for async execution or **`.batch_submit()`** for many submissions) ## Add dependencies @@ -100,7 +101,7 @@ def compute_mean(data: list[float]) -> float: ``` !!! tip "Updating Python version" - You can also use `hog add` to update the Python version requirement: + You can also use `hog add` to update the Python version requirement, not just add dependencies: ```bash hog add hello.py --python 3.11 diff --git a/docs/index.md b/docs/index.md index 1218cf5..ac943b5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -191,7 +191,7 @@ hog run analysis.py ## What Makes Groundhog Different? **Environment and code stay coupled** -: Change your Python version or dependencies by editing the PEP 723 block in your script. The remote environment rebuilds automatically on the next run. +: Change your Python version or dependencies by editing the PEP 723 block in your script. The remote environment rebuilds automatically (if necessary) on the next run. **Globus Compute under the hood** : Built on [Globus Compute](https://www.globus.org/compute) for robust, secure HPC job submission. diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 6c767af..52df2cc 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -415,8 +415,8 @@ def batch_submit( def batch_local( self, - args: list[tuple] = [], - kwargs: list[dict] = [], + args: list[tuple] | None = None, + kwargs: list[dict] | None = None, executor_kwargs: dict[str, Any] | None = None, ) -> list[GroundhogFuture]: """Execute the function locally in parallel subprocesses for each task.
@@ -437,6 +437,7 @@ def batch_local( ModuleImportError: If called during module import ValueError: If both args and kwargs are empty """ + args, kwargs = args or [], kwargs or [] module = sys.modules.get(self._wrapped_function.__module__) if not getattr(module, "__groundhog_imported__", False): raise ModuleImportError(