diff --git a/docs/api/function.md b/docs/api/function.md index d3179c1..11913d4 100644 --- a/docs/api/function.md +++ b/docs/api/function.md @@ -14,6 +14,8 @@ The `Function` class wraps user functions decorated with `@hog.function()` to en - remote - submit - local + - batch_submit + - batch_local ## Method Class diff --git a/docs/concepts/functions-and-harnesses.md b/docs/concepts/functions-and-harnesses.md index 002c01f..b0bebd1 100644 --- a/docs/concepts/functions-and-harnesses.md +++ b/docs/concepts/functions-and-harnesses.md @@ -17,7 +17,7 @@ def train_model(dataset: str, epochs: int) -> dict: return {"accuracy": 0.95} ``` -Functions provide four execution modes: +Functions provide several execution modes: | Method | Where it runs | Behavior | |--------|---------------|----------| @@ -111,6 +111,6 @@ hog run script.py -- --epochs=20 # Runs main with epochs=20 ## Next steps -- **[Parallel Execution](../examples/parallel-execution.md)** - Use `.submit()` to run functions concurrently +- **[Parallel Execution](../examples/parallel-execution.md)** - Using `.batch_*` methods to run functions concurrently - **[Parameterized Harness Example](../examples/parameterized-harness.md)** - Complete example with CLI arguments - **[Remote Execution Flow](remote-execution.md)** - Understand what happens when you call `.remote()` diff --git a/docs/examples/index.md b/docs/examples/index.md index 572b8e8..b39e182 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -15,7 +15,7 @@ These examples cover the basics of using Groundhog: Examples showing how to handle typical workflows: -- **[Parallel Execution](parallel-execution.md)** - Using `.submit()` for concurrent remote execution +- **[Parallel Execution](parallel-execution.md)** - Using `.batch_submit()` or `.batch_local()` for concurrent execution - **[Parameterized Harnesses](parameterized-harness.md)** - Harnesses that accept CLI arguments for runtime configuration - **[Endpoint Configuration](configuration.md)** - How the configuration system merges settings from multiple sources (PEP 723, decorators, call-time overrides) - **[PyTorch from Custom Sources](pytorch_custom_index.md)** - Configuring uv to install packages from cluster-specific indexes, local paths, or internal mirrors diff --git a/docs/examples/local.md b/docs/examples/local.md index a06e919..411dc78 100644 --- a/docs/examples/local.md +++ b/docs/examples/local.md @@ -128,5 +128,5 @@ Using .local() - runs in subprocess with numpy installed: ## Next Steps -- **[Parallel Execution](parallel-execution.md)** - Run multiple functions concurrently with `.submit()` +- **[Parallel Execution](parallel-execution.md)** - Run multiple functions concurrently with `.batch_submit()` or `.batch_local()` - **[Configuration](configuration.md)** - Configure multiple endpoints diff --git a/docs/examples/parallel-execution.md b/docs/examples/parallel-execution.md index 14738eb..87c04b3 100644 --- a/docs/examples/parallel-execution.md +++ b/docs/examples/parallel-execution.md @@ -1,6 +1,6 @@ # Parallel Execution -This example demonstrates the difference between sequential execution with `.remote()` and parallel execution with `.submit()`. +This example demonstrates sequential execution with `.remote()`, parallel execution with `.submit()`, and batch execution with `.batch_submit()` and `.batch_local()`. 
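+
+At a glance (a sketch; `slow_square` is the example function defined below):
+
+```python
+result = slow_square.remote(2)           # blocks until the remote task finishes
+future = slow_square.submit(2)           # returns a GroundhogFuture immediately
+futures = slow_square.batch_submit(args=[(1,), (2,)])  # one API call for all tasks
+futures = slow_square.batch_local(args=[(1,), (2,)])   # parallel local subprocesses
+```
+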
## When to Use Each Method @@ -16,11 +16,21 @@ This example demonstrates the difference between sequential execution with `.rem **Use `.submit()` when:** -- You have multiple independent tasks that can run concurrently - You don't care for the console display - You need access to the `GroundhogFuture` object -## Full Example +**Use `.batch_submit()` when:** + +- You're submitting many tasks to the same remote endpoint +- You want to avoid Globus Compute rate limits (batching is one API call instead of N) +- All tasks use the same function with different arguments + +**Use `.batch_local()` when:** + +- You want to run many tasks in parallel locally +- You want immediate `GroundhogFuture`s instead of `.local()`'s blocking behavior + +## Example: Remote vs Submit ```python title="parallel_execution.py" # /// script @@ -28,7 +38,7 @@ This example demonstrates the difference between sequential execution with `.rem # dependencies = [] # # [tool.uv] -# exclude-newer = "2025-12-02T19:48:40Z" +# exclude-newer = "2026-03-06T00:00:00Z" # # [tool.hog.anvil] # endpoint = "5aafb4c1-27b2-40d8-a038-a0277611868f" @@ -66,6 +76,28 @@ def main(): results = [f.result() for f in futures] # (3)! print(f" Results: {results}") print(f" Time: {time.time() - start:.1f}s (approximately 2s)") + + +@hog.harness() +def batch(): + """Run with: hog run parallel_execution.py batch""" + # .batch_submit() registers the function once and sends all tasks in a + # single API request, avoiding the per-task rate limits of a .submit() loop. + print("Batch remote submission:") + futures = slow_square.batch_submit( + args=[(0,), (1,), (2,), (3,), (4,)], + ) + results = [f.result() for f in futures] + print(f" Results: {results}") # [0, 1, 4, 9, 16] + + # .batch_local() runs each task in its own subprocess in parallel. + print("Batch local execution:") + futures = slow_square.batch_local( + args=[(0,), (1,), (2,), (3,), (4,)], + executor_kwargs={"max_workers": 4}, + ) + results = [f.result() for f in futures] + print(f" Results: {results}") # [0, 1, 4, 9, 16] ``` 1. `.remote()` blocks until the function completes. Each call waits for the previous one to finish. Total time: 3 tasks x 2 seconds = ~6 seconds. @@ -74,45 +106,90 @@ def main(): 3. Calling `.result()` on each future blocks until that task completes. Since all tasks run in parallel, total time is ~2 seconds. + +## Example: Batching Locally / Remotely + +A loop of `.submit()` calls makes one API request per task and can hit Globus Compute rate limits at large N. `.batch_submit()` registers the function once and sends all tasks in a single request. + +```python +# Instead of this (N separate API calls): +futures = [slow_square.submit(i) for i in range(5)] + +# Use batch_submit (one API call): +futures = slow_square.batch_submit( + args=[(0,), (1,), (2,), (3,), (4,)], # (1)! +) +results = [f.result() for f in futures] +# [0, 1, 4, 9, 16] +``` + +1. Each tuple is unpacked as positional arguments for one task. Pass `kwargs=[...]` alongside `args` to mix positional and keyword arguments — when the two lists have different lengths, the shorter one fills with `()` or `{}`. + +`.batch_local()` runs each task in its own subprocess with an isolated temporary directory: + +```python +futures = slow_square.batch_local( + args=[(0,), (1,), (2,), (3,), (4,)], + executor_kwargs={"max_workers": 4}, # (1)! +) +results = [f.result() for f in futures] +# [0, 1, 4, 9, 16] +``` + +1. `executor_kwargs` is forwarded directly to `ThreadPoolExecutor`. Omit it to use the default worker count. 
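+
+To mix positional and keyword arguments, pass both lists (a sketch; `scaled_square` is a hypothetical `@hog.function` computing `n * n * factor`):
+
+```python
+futures = scaled_square.batch_submit(
+    args=[(1,), (2,), (3,)],
+    kwargs=[{"factor": 10}, {"factor": 20}],  # third task falls back to {}
+)
+results = [f.result() for f in futures]  # [10, 80, 9]
+```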
+ ## Working with GroundhogFutures +`.submit()` and both batch methods return `GroundhogFuture` objects. They behave like standard `concurrent.futures.Future` objects, with additional Groundhog-specific properties. + ```python future = slow_square.submit(5) +# Get the deserialized return value (blocks until ready) +result = future.result() +result = future.result(timeout=10) # Raises TimeoutError if not ready + # Check if done (non-blocking) if future.done(): print("Task completed!") -# Get the result (blocks until ready) -result = future.result() - -# Get result with timeout -result = future.result(timeout=10) # Raises TimeoutError if not ready - # Cancel a pending task future.cancel() -# Inspect the underlying ShellResult +# Inspect raw shell execution metadata print(future.shell_result.returncode) print(future.shell_result.stderr) + +# Capture stdout from print() calls inside the remote function +if future.user_stdout: + print(future.user_stdout) + +# Inspect the resolved configuration that was actually passed to the endpoint +print(future.user_endpoint_config) # {"account": "...", "partition": "..."} +print(future.task_id) # Globus Compute task ID +print(future.function_name) # "slow_square" ``` ## Running the Example ```bash +# sequential vs batch timing comparison (local methods) hog run examples/parallel_execution.py + +# .remote vs .submit vs .batch_submit +hog run examples/parallel_execution.py remote ``` -Expected output: +Expected output from `main`: ``` -Sequential execution with .remote(): - Results: [0, 1, 4] - Time: 6.2s (approximately 6s) +Sequential execution with .local(): + Results: [0, 1, 4, 9, 16] + Time: 11.1s -Parallel execution with .submit(): - Results: [0, 1, 4] - Time: 2.1s (approximately 2s) +Parallel execution with .batch_local(): + Results: [0, 1, 4, 9, 16] + Time: 2.2s ``` ## Next Steps diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index 1bfa1d2..0d2ccd5 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -54,14 +54,15 @@ The comment block at the top uses [PEP 723](https://peps.python.org/pep-0723/) i - **`requires-python`**: Python version requirement for remote execution - **`dependencies`**: Python packages needed by your function (managed by uv) -- **`[tool.uv]`**: Optional configuration read by `uv run` when creating the ephemeral remote environment (see also: [full uv settings reference](https://docs.astral.sh/uv/reference/settings/)) -- **`[tool.hog.my-endpoint]`**: Endpoint configuration with HPC-specific settings like account, partition, walltime, etc. +- **`[tool.uv]`**: Optional configuration read by `uv venv` and `uv pip install` when creating the remote environment (see also: [full uv settings reference](https://docs.astral.sh/uv/reference/settings/)) +- **`[tool.hog.my-endpoint]`**: Endpoint configuration with HPC-specific settings like account, partition, walltime, etc. Recognized configuration options depend on the particular endpoint. + ### Functions and harnesses - **`@hog.function()`**: Decorates a Python function to make it executable remotely - **`@hog.harness()`**: Decorates an orchestrator function that coordinates remote calls. 
Harnesses can accept parameters passed as CLI arguments (see [Functions and Harnesses](../concepts/functions-and-harnesses.md))
-- **`.remote()`**: Executes the function remotely and blocks until complete (alternatively, use **`.submit()`** for async execution)
+- **`.remote()`**: Executes the function remotely and blocks until complete (alternatively, use **`.submit()`** for async execution or **`.batch_submit()`** for submitting many tasks at once)

 ## Add dependencies

@@ -100,7 +101,7 @@ def compute_mean(data: list[float]) -> float:
 ```

 !!! tip "Updating Python version"
-    You can also use `hog add` to update the Python version requirement:
+    You can also use `hog add` to update the Python version requirement, not just add dependencies:

     ```bash
     hog add hello.py --python 3.11
diff --git a/docs/index.md b/docs/index.md
index 1218cf5..ac943b5 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -191,7 +191,7 @@ hog run analysis.py

 ## What Makes Groundhog Different?

 **Environment and code stay coupled**
-: Change your Python version or dependencies by editing the PEP 723 block in your script. The remote environment rebuilds automatically on the next run.
+: Change your Python version or dependencies by editing the PEP 723 block in your script. The remote environment rebuilds automatically (if necessary) on the next run.

 **Globus Compute under the hood**
 : Built on [Globus Compute](https://www.globus.org/compute) for robust, secure HPC job submission.
diff --git a/examples/parallel_execution.py b/examples/parallel_execution.py
index 47d08dc..ba4ce64 100644
--- a/examples/parallel_execution.py
+++ b/examples/parallel_execution.py
@@ -3,18 +3,19 @@
 # dependencies = []
 #
 # [tool.uv]
-# exclude-newer = "2025-12-02T19:48:40Z"
+# exclude-newer = "2026-03-06T00:00:00Z"
 #
 # [tool.hog.anvil]
 # endpoint = "5aafb4c1-27b2-40d8-a038-a0277611868f"
 # account = "cis250461"
-# requirements = ""
 # ///
 """
-Example showing parallel execution with .submit() vs sequential with .remote().
+Example showing parallel and batch execution patterns.

 Use .remote() when you want to wait for each result before continuing.
 Use .submit() when you want to run multiple tasks in parallel.
+Use .batch_submit() to submit many tasks without hitting rate limits.
+Use .batch_local() for parallel local execution on the login node.
""" import groundhog_hpc as hog @@ -30,21 +31,50 @@ def slow_square(n: int) -> int: @hog.harness() -def main(): - """Run with: hog run parallel_execution.py""" +def main(n: int = 5): + """Run like: hog run parallel_execution.py -- --n=5""" import time + # Sequential: each .remote() blocks until complete + print("Sequential execution with .local():") + start = time.time() + results = [slow_square.local(i) for i in range(n)] + print(f" Results: {results}") + print(f" Time: {time.time() - start:.1f}s \n") + + print("Parallel execution with .batch_local():") + start = time.time() + futures = slow_square.batch_local(args=[(i,) for i in range(n)]) + results = [f.result() for f in futures] + print(f" Results: {results}") + print(f" Time: {time.time() - start:.1f}s ") + + +@hog.harness() +def remote(n: int = 5): + """Run like: hog run parallel_execution.py remote -- --n=5""" + import time + + args_list = [(i,) for i in range(n)] # Sequential: each .remote() blocks until complete print("Sequential execution with .remote():") start = time.time() - results = [slow_square.remote(i) for i in range(3)] + results = [slow_square.remote(*args) for args in args_list] print(f" Results: {results}") - print(f" Time: {time.time() - start:.1f}s (approximately 6s)\n") + print(f" Time: {time.time() - start:.1f}s \n") - # Parallel: .submit() returns immediately, tasks run concurrently + # Parallel: .submit() returns immediately, tasks run ~concurrently (N globus api calls) print("Parallel execution with .submit():") start = time.time() - futures = [slow_square.submit(i) for i in range(3)] - results = [f.result() for f in futures] # Wait for all results + futures = [slow_square.submit(*args) for args in args_list] + results = [f.result() for f in futures] + print(f" Results: {results}") + print(f" Time: {time.time() - start:.1f}s ") + + # Parallel: .batch_submit() returns immediately, tasks run concurrently (1 globus api call) + print("Parallel execution with .batch_submit():") + start = time.time() + futures = slow_square.batch_submit(args=args_list) + results = [f.result() for f in futures] print(f" Results: {results}") - print(f" Time: {time.time() - start:.1f}s (approximately 2s)") + print(f" Time: {time.time() - start:.1f}s ") diff --git a/src/groundhog_hpc/compute.py b/src/groundhog_hpc/compute.py index 8e3bb14..03d0d9d 100644 --- a/src/groundhog_hpc/compute.py +++ b/src/groundhog_hpc/compute.py @@ -1,19 +1,20 @@ """Globus Compute execution interface. -This module provides functions for converting user scripts into Globus Compute -ShellFunctions, registering them, and submitting them for execution on remote +This module provides functions for building Globus Compute ShellFunctions from +pre-rendered shell command strings and submitting them for execution on remote endpoints. """ import logging import os +import threading import warnings +from concurrent.futures import Future as ConcurrentFuture from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeVar from uuid import UUID from groundhog_hpc.future import GroundhogFuture -from groundhog_hpc.templating import template_shell_command logger = logging.getLogger(__name__) @@ -46,43 +47,43 @@ def _get_compute_client() -> Client: return gc.Client() -def script_to_submittable( - script_path: str, - function_name: str, - payload: str, +def build_shell_function( + shell_command: str, + name: str, walltime: int | float | None = None, ) -> ShellFunction: - """Convert a user script and function name into a Globus Compute ShellFunction. 
+ """Create a Globus Compute ShellFunction from a pre-rendered shell command string. Args: - script_path: Path to the Python script containing the function - function_name: Name of the function to execute remotely - payload: Serialized arguments string - walltime: Optional maximum execution time in seconds for ShellFunction timeout + shell_command: The shell command string (may contain {payload} placeholder) + name: Function name used as the ShellFunction name (dots replaced with underscores) + walltime: Optional maximum execution time in seconds Returns: A ShellFunction ready to be submitted to a Globus Compute executor """ import globus_compute_sdk as gc - shell_command = template_shell_command(script_path, function_name, payload) - shell_function = gc.ShellFunction( - shell_command, name=function_name.replace(".", "_"), walltime=walltime + return gc.ShellFunction( + shell_command, name=name.replace(".", "_"), walltime=walltime ) - return shell_function def submit_to_executor( endpoint: UUID, user_endpoint_config: dict[str, Any], shell_function: ShellFunction, + payload: str, + executor_kwargs: dict[str, Any] | None = None, ) -> GroundhogFuture: """Submit a ShellFunction to a Globus Compute endpoint for execution. Args: endpoint: UUID of the Globus Compute endpoint user_endpoint_config: Configuration dict for the endpoint (e.g., worker_init, walltime) - shell_function: The ShellFunction to execute (with payload already templated in) + shell_function: The parameterized ShellFunction to execute + payload: Serialized arguments string, substituted into the {payload} placeholder + executor_kwargs: Extra keyword arguments forwarded directly to gc.Executor constructor Returns: A GroundhogFuture that will contain the deserialized result @@ -101,12 +102,14 @@ def submit_to_executor( config = {k: v for k, v in config.items() if k not in unexpected_keys} logger.debug(f"Creating Globus Compute executor for endpoint {endpoint}") - with gc.Executor(endpoint, user_endpoint_config=config) as executor: + with gc.Executor( + endpoint, user_endpoint_config=config, **(executor_kwargs or {}) + ) as executor: func_name = getattr( shell_function, "__name__", getattr(shell_function, "name", "unknown") ) logger.info(f"Submitting function '{func_name}' to endpoint '{endpoint}'") - future = executor.submit(shell_function) + future = executor.submit(shell_function, payload=payload) task_id = getattr(future, "task_id", None) if task_id: logger.info(f"Task submitted with ID: {task_id}") @@ -114,6 +117,104 @@ def submit_to_executor( return deserializing_future +def submit_batch( + endpoint: UUID, + user_endpoint_config: dict[str, Any], + shell_function: ShellFunction, + payloads: list[str], +) -> list[GroundhogFuture]: + """Submit a parameterized ShellFunction as a batch of tasks to a Globus Compute endpoint. + + Registers the ShellFunction once, then submits all payloads as a single batch + request, avoiding per-task API calls that can hit rate limits. 
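+
+    A sketch of the expected call pattern (cmd, endpoint_uuid, and payloads are
+    placeholders for values built by the caller):
+
+        shell_fn = build_shell_function(cmd, "my_func")
+        futures = submit_batch(endpoint_uuid, {}, shell_fn, payloads)
+        results = [f.result() for f in futures]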
+ + Args: + endpoint: UUID of the Globus Compute endpoint + user_endpoint_config: Configuration dict for the endpoint + shell_function: The parameterized ShellFunction (registered once for all tasks) + payloads: List of serialized argument strings, one per task + + Returns: + A list of GroundhogFutures in the same order as payloads + """ + client = _get_compute_client() + + config = user_endpoint_config.copy() + if schema := get_endpoint_schema(endpoint): + expected_keys = set(schema.get("properties", {}).keys()) + unexpected_keys = set(config.keys()) - expected_keys + if unexpected_keys: + logger.debug( + f"Filtering unexpected config keys for endpoint {endpoint}: {unexpected_keys}" + ) + config = {k: v for k, v in config.items() if k not in unexpected_keys} + + func_name = getattr(shell_function, "__name__", "unknown") + function_id = client.register_function(shell_function) + logger.info( + f"Registered '{func_name}' for batch submission, function_id={function_id}" + ) + + batch = client.create_batch(user_endpoint_config=config) + for payload in payloads: + batch.add(function_id, kwargs={"payload": payload}) + + response = client.batch_run(endpoint, batch) + task_ids: list[str] = response["tasks"][function_id] + logger.info(f"Batch submitted: {len(task_ids)} tasks to endpoint '{endpoint}'") + + task_id_to_future: dict[str, ConcurrentFuture] = { + tid: ConcurrentFuture() for tid in task_ids + } + + thread = threading.Thread( + target=_poll_batch_results, + args=(dict(task_id_to_future), client), + daemon=True, + ) + thread.start() + + futures = [] + for task_id in task_ids: + gf = GroundhogFuture(task_id_to_future[task_id]) + gf._task_id = task_id + futures.append(gf) + + return futures + + +def _poll_batch_results( + task_id_to_future: dict[str, ConcurrentFuture], + client: Client, + poll_interval: float = 1.0, +) -> None: + """Background thread: poll Globus Compute until all batch tasks are resolved.""" + import time + + pending = dict(task_id_to_future) + + while pending: + results = client.get_batch_result(list(pending.keys())) + for task_id, status in results.items(): + if status.get("pending", True): + continue + fut = pending.pop(task_id) + try: + if "result" in status: + fut.set_result(status["result"]) + else: + try: + status["exception"].reraise() + except Exception as e: + fut.set_exception(e) + except Exception as e: + if not fut.done(): + fut.set_exception(e) + + if pending: + time.sleep(poll_interval) + + def get_task_status(task_id: str | UUID | None) -> dict[str, Any]: """Get the full task status response from Globus Compute. 
diff --git a/src/groundhog_hpc/function.py b/src/groundhog_hpc/function.py index 97f641f..52df2cc 100644 --- a/src/groundhog_hpc/function.py +++ b/src/groundhog_hpc/function.py @@ -10,16 +10,19 @@ """ import inspect +import itertools import logging import os +import subprocess import sys import tempfile +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from types import FunctionType from typing import TYPE_CHECKING, Any, TypeVar from uuid import UUID -from groundhog_hpc.compute import script_to_submittable, submit_to_executor +from groundhog_hpc.compute import build_shell_function, submit_batch, submit_to_executor from groundhog_hpc.configuration.resolver import ConfigResolver from groundhog_hpc.console import display_task_status from groundhog_hpc.errors import ( @@ -29,6 +32,7 @@ ) from groundhog_hpc.future import GroundhogFuture from groundhog_hpc.serialization import deserialize_stdout, serialize +from groundhog_hpc.templating import template_shell_command from groundhog_hpc.utils import prefix_output logger = logging.getLogger(__name__) @@ -43,6 +47,35 @@ ShellResult = TypeVar("ShellResult") +def _run_shell_locally(cmd_template: str, payload: str, tmpdir: str) -> ShellResult: + """Execute a parameterized shell command locally. + + Injects GC_TASK_SANDBOX_DIR into the subprocess environment without + mutating os.environ, making concurrent calls thread-safe. + """ + import globus_compute_sdk as gc + + env = {**os.environ, "GC_TASK_SANDBOX_DIR": tmpdir} + cmd = cmd_template.format(payload=payload) + proc = subprocess.run( + cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + text=True, + env=env, + ) + return gc.ShellResult( + cmd=cmd, + stdout=proc.stdout, + stderr=proc.stderr, + returncode=proc.returncode, + exception_name="subprocess.CalledProcessError" + if proc.returncode != 0 + else None, + ) + + class Function: """Wrapper that enables a Python function to be executed remotely on Globus Compute. @@ -79,12 +112,18 @@ def __init__( # ShellFunction walltime - always None here to prevent conflicts with a # 'walltime' endpoint config, but the attribute exists as an escape - # hatch if users need to set it after the function's been created + # hatch if users need to set it after the function's been created. + # NOTE: walltime must be set before the first .submit() or .local() call; + # changing it afterwards has no effect because shell_function is cached. self.walltime: int | float | None = None self._wrapped_function: FunctionType = func self._config_resolver: ConfigResolver | None = None + # Cached parameterized shell command and ShellFunction (built once, reused per instance) + self._shell_command: str | None = None + self._shell_function: ShellFunction | None = None + def __call__(self, *args: Any, **kwargs: Any) -> Any: """Execute the function locally (not remotely). @@ -110,6 +149,7 @@ def submit( *args: Any, endpoint: str | None = None, user_endpoint_config: dict[str, Any] | None = None, + executor_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> GroundhogFuture: """Submit the function for asynchronous remote execution. @@ -119,6 +159,7 @@ def submit( endpoint: Globus Compute endpoint UUID (or named endpoint from `[tool.hog.]` PEP 723 metadata). Replaces decorator default. 
user_endpoint_config: Endpoint configuration dict (merged with decorator default) + executor_kwargs: Keyword arguments forwarded to Globus Compute Executor **kwargs: Keyword arguments to pass to the function Returns: @@ -177,14 +218,13 @@ def submit( f"Serializing {len(args)} args and {len(kwargs)} kwargs for '{self.name}'" ) payload = serialize((args, kwargs), use_proxy=False, proxy_threshold_mb=None) - shell_function = script_to_submittable( - self.script_path, self.name, payload, walltime=self.walltime - ) future: GroundhogFuture = submit_to_executor( UUID(endpoint), user_endpoint_config=config, - shell_function=shell_function, + shell_function=self.shell_function, + payload=payload, + executor_kwargs=executor_kwargs, ) future.endpoint = endpoint future.user_endpoint_config = config @@ -196,6 +236,7 @@ def remote( *args: Any, endpoint: str | None = None, user_endpoint_config: dict[str, Any] | None = None, + executor_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Execute the function remotely and block until completion. @@ -208,6 +249,7 @@ def remote( endpoint: Globus Compute endpoint UUID (or named endpoint from `[tool.hog.]` PEP 723 metadata). Replaces decorator default. user_endpoint_config: Endpoint configuration dict (merged with decorator default) + executor_kwargs: Keyword arguments forwarded to Globus Compute Executor **kwargs: Keyword arguments to pass to the function Returns: @@ -224,6 +266,7 @@ def remote( *args, endpoint=endpoint, user_endpoint_config=user_endpoint_config, + executor_kwargs=executor_kwargs, **kwargs, ) display_task_status(future) @@ -260,18 +303,10 @@ def local(self, *args: Any, **kwargs: Any) -> Any: logger.debug(f"Executing function '{self.name}' in local subprocess") with prefix_output(prefix="[local]", prefix_color="blue"): - # Create ShellFunction just like we do for remote execution payload = serialize((args, kwargs), proxy_threshold_mb=1.0) - shell_function = script_to_submittable(self.script_path, self.name, payload) with tempfile.TemporaryDirectory() as tmpdir: - # set sandbox dir for ShellFunction to use - if "GC_TASK_SANDBOX_DIR" not in os.environ: - os.environ["GC_TASK_SANDBOX_DIR"] = tmpdir - - # just __call__ ShellFunction to execute the command - result = shell_function() - assert not isinstance(result, dict) + result = _run_shell_locally(self.shell_function.cmd, payload, tmpdir) if result.returncode != 0: logger.error( @@ -305,6 +340,156 @@ def local(self, *args: Any, **kwargs: Any) -> Any: print(user_stdout, file=sys.stdout) return deserialized_result + def batch_submit( + self, + args: list[tuple] | None = None, + kwargs: list[dict] | None = None, + endpoint: str | None = None, + user_endpoint_config: dict[str, Any] | None = None, + ) -> list[GroundhogFuture]: + """Submit the function for asynchronous remote execution as a batch. + + Submits all tasks as a single Globus Compute batch request, avoiding + per-task API calls that can hit rate limits. 
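+
+        Example (a sketch; my_func is a hypothetical @hog.function accepting an
+        optional keyword argument x):
+
+            futures = my_func.batch_submit(args=[(1,), (2,)], kwargs=[{"x": 3}])
+            results = [f.result() for f in futures]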
+ + Args: + args: List of positional-argument tuples, one per task + kwargs: List of keyword-argument dicts, one per task + endpoint: Globus Compute endpoint UUID or named endpoint + user_endpoint_config: Endpoint configuration dict (merged with decorator default) + + Returns: + A list of GroundhogFutures in the same order as the input tasks + + Raises: + ModuleImportError: If called during module import + ValueError: If both args and kwargs are empty + """ + args = args or [] + kwargs = kwargs or [] + module = sys.modules.get(self._wrapped_function.__module__) + if not getattr(module, "__groundhog_imported__", False): + raise ModuleImportError( + self._wrapped_function.__name__, + "batch_submit", + self._wrapped_function.__module__, + ) + + if max(len(args), len(kwargs)) == 0: + raise ValueError( + "batch_submit requires at least one task: args and kwargs are both empty" + ) + + endpoint = endpoint or self.endpoint + decorator_config = self.default_user_endpoint_config.copy() + call_time_config = user_endpoint_config.copy() if user_endpoint_config else {} + config = self.config_resolver.resolve( + endpoint_name=endpoint or "", + decorator_config=decorator_config, + call_time_config=call_time_config, + ) + if "endpoint" in config: + endpoint = config.pop("endpoint") + if not endpoint: + available_endpoints = self._get_available_endpoints_from_pep723() + if available_endpoints: + endpoints_str = ", ".join(f"'{e}'" for e in available_endpoints) + raise ValueError( + f"No endpoint specified. Available endpoints found in config: {endpoints_str}." + ) + raise ValueError("No endpoint specified") + + payloads = [] + for a, kw in itertools.zip_longest(args, kwargs, fillvalue=None): + a = a if a is not None else () + kw = kw if kw is not None else {} + payloads.append( + serialize((a, kw), use_proxy=False, proxy_threshold_mb=None) + ) + + futures = submit_batch(UUID(endpoint), config, self.shell_function, payloads) + for future in futures: + future.function_name = self.name + future.user_endpoint_config = config + return futures + + def batch_local( + self, + args: list[tuple] | None = None, + kwargs: list[dict] | None = None, + executor_kwargs: dict[str, Any] | None = None, + ) -> list[GroundhogFuture]: + """Execute the function locally in parallel subprocesses for each task. + + Submits all tasks to a ThreadPoolExecutor immediately and returns futures + without waiting for completion. Each task runs in its own subprocess with + an isolated temporary directory. 
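+
+        Example (a sketch; my_func is a hypothetical @hog.function taking one
+        positional argument):
+
+            futures = my_func.batch_local(
+                args=[(i,) for i in range(4)],
+                executor_kwargs={"max_workers": 2},
+            )
+            results = [f.result() for f in futures]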
+ + Args: + args: List of positional-argument tuples, one per task + kwargs: List of keyword-argument dicts, one per task + executor_kwargs: Keyword arguments forwarded to ThreadPoolExecutor + + Returns: + A list of GroundhogFutures in the same order as the input tasks + + Raises: + ModuleImportError: If called during module import + ValueError: If both args and kwargs are empty + """ + args, kwargs = args or [], kwargs or [] + module = sys.modules.get(self._wrapped_function.__module__) + if not getattr(module, "__groundhog_imported__", False): + raise ModuleImportError( + self._wrapped_function.__name__, + "batch_local", + self._wrapped_function.__module__, + ) + + if max(len(args), len(kwargs)) == 0: + raise ValueError( + "batch_local requires at least one task: args and kwargs are both empty" + ) + + payloads = [] + for a, kw in itertools.zip_longest(args, kwargs, fillvalue=None): + a = a if a is not None else () + kw = kw if kw is not None else {} + payloads.append(serialize((a, kw), proxy_threshold_mb=1.0)) + + cmd_template = self.shell_function.cmd + + def _worker(payload: str) -> ShellResult: + with tempfile.TemporaryDirectory() as tmpdir: + return _run_shell_locally(cmd_template, payload, tmpdir) + + executor = ThreadPoolExecutor(**(executor_kwargs or {})) + return [GroundhogFuture(executor.submit(_worker, p)) for p in payloads] + + @property + def shell_command(self) -> str: + """Parameterized shell command string with a {payload} placeholder. + + Generated once from the script file and cached. The same command string + is reused for all invocations of this function. + """ + if self._shell_command is None: + self._shell_command = template_shell_command(self.script_path, self.name) + return self._shell_command + + @property + def shell_function(self) -> ShellFunction: + """Cached Globus Compute ShellFunction built from the parameterized shell command. + + Created once and reused for all .submit() and .local() calls, so the + same ShellFunction object handles concurrent invocations. + """ + if self._shell_function is None: + self._shell_function = build_shell_function( + self.shell_command, self.name, walltime=self.walltime + ) + return self._shell_function + @property def script_path(self) -> str: """Get the script path for this function. diff --git a/src/groundhog_hpc/future.py b/src/groundhog_hpc/future.py index d5597fc..2f09750 100644 --- a/src/groundhog_hpc/future.py +++ b/src/groundhog_hpc/future.py @@ -98,7 +98,7 @@ def task_id(self) -> str | None: Returns the task ID from the underlying Globus Compute future, which may not be populated immediately. 
""" - return self._original_future.task_id # type: ignore[attr-defined] + return self._task_id or getattr(self._original_future, "task_id", None) @property def endpoint(self) -> str | None: diff --git a/src/groundhog_hpc/templates/shell_command.sh.jinja b/src/groundhog_hpc/templates/shell_command.sh.jinja index 3a81241..c3e08e8 100644 --- a/src/groundhog_hpc/templates/shell_command.sh.jinja +++ b/src/groundhog_hpc/templates/shell_command.sh.jinja @@ -1,7 +1,7 @@ set -euo pipefail -# Cleanup temporary files on exit (env is preserved for reuse) -trap 'rm -f {{ user_script_name }}.py {{ runner_name }}.py {{ script_name }}.in {{ script_name }}.out' EXIT +TASK_DIR=$(mktemp -d) +trap 'rm -rf "$TASK_DIR"' EXIT if command -v uv &> /dev/null; then UV_BIN=$(command -v uv) @@ -51,16 +51,16 @@ export GROUNDHOG_LOG_LEVEL="${{GROUNDHOG_LOG_LEVEL:-WARNING}}" {% endraw %} {% endif %} -cat > {{ user_script_name }}.py << 'USER_SCRIPT_EOF' +cat > "$TASK_DIR/user_script.py" << 'USER_SCRIPT_EOF' {{ user_script_contents | escape_braces }} USER_SCRIPT_EOF -cat > {{ runner_name }}.py << 'RUNNER_EOF' +cat > "$TASK_DIR/runner.py" << 'RUNNER_EOF' {{ runner_contents | escape_braces }} RUNNER_EOF -cat > {{ script_name }}.in << 'PAYLOAD_EOF' -{{ payload }} +cat > "$TASK_DIR/payload.in" << 'PAYLOAD_EOF' +{payload} PAYLOAD_EOF # Check if environment exists; create if not @@ -115,7 +115,8 @@ META_EOF fi # Run using the cached environment's Python directly (bypasses uv resolution) -"$ENV_DIR/bin/python" {{ runner_name }}.py +cd "$TASK_DIR" +"$ENV_DIR/bin/python" runner.py echo "__GROUNDHOG_RESULT__" -cat {{ script_name }}.out +cat payload.out diff --git a/src/groundhog_hpc/templating.py b/src/groundhog_hpc/templating.py index 5d0eb16..020483c 100644 --- a/src/groundhog_hpc/templating.py +++ b/src/groundhog_hpc/templating.py @@ -11,7 +11,6 @@ import logging import os import re -import uuid from datetime import datetime, timezone from hashlib import sha1 from pathlib import Path @@ -63,32 +62,34 @@ def compute_env_hash(metadata: Pep723Metadata) -> str: return sha1(canonical.encode("utf-8")).hexdigest()[:8] -def template_shell_command(script_path: str, function_name: str, payload: str) -> str: - """Generate a shell command to execute a user function on a remote endpoint. +def template_shell_command(script_path: str, function_name: str) -> str: + """Generate a parameterized shell command for remote execution. - The generated shell command: - - Creates a runner script that imports the user script as a module - - Writes the user script to a file (unmodified) - - Sets up input/output files for serialized data - - Executes the runner with uv for dependency management + The payload is NOT baked into the command. Instead, a {payload} format + placeholder is left so a single ShellFunction can be reused for all + invocations of the same function: + + shell_function(payload=serialized_payload) + + which calls cmd.format(payload=serialized_payload) before execution. + + File isolation is provided by mktemp -d per invocation so concurrent tasks + on the same node don't collide. 
Args: script_path: Path to the user's Python script function_name: Name of the function to execute - payload: Serialized arguments string Returns: - A fully-formed shell command string ready to be executed via Globus - Compute or local subprocess + A shell command string containing a {payload} format placeholder """ logger.debug( - f"Templating shell command for function '{function_name}' in script '{script_path}'" + f"Templating shell command for function '{function_name}' in '{script_path}'" ) with open(script_path, "r") as f_in: user_script = f_in.read() - # Extract PEP 723 metadata for the runner metadata = read_pep723(user_script) pep723_metadata = write_pep723(metadata) if metadata else "" @@ -101,18 +102,6 @@ def template_shell_command(script_path: str, function_name: str, payload: str) - ) env_hash = _script_hash_prefix(user_script) - script_hash = _script_hash_prefix(user_script) - script_basename = _extract_script_basename(script_path) - random_suffix = uuid.uuid4().hex[:8] - script_name = f"{script_basename}-{script_hash}-{random_suffix}" - - # Generate names for the user script and runner - user_script_name = script_name - runner_name = f"{script_name}_runner" - user_script_path_remote = f"{user_script_name}.py" - payload_path = f"{script_name}.in" - outfile_path = f"{script_name}.out" - version_spec = get_groundhog_version_spec() logger.debug(f"Using groundhog version spec: {version_spec}") semver_match = re.search(r"==([0-9][^\s]*)", version_spec) @@ -124,27 +113,22 @@ def template_shell_command(script_path: str, function_name: str, payload: str) - else: groundhog_version = _script_hash_prefix(version_spec) - # Generate timestamp for groundhog-hpc exclude-newer override - # This allows groundhog to bypass user's exclude-newer restrictions groundhog_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - # Load runner template templates_dir = Path(__file__).parent / "templates" jinja_env = Environment(loader=FileSystemLoader(templates_dir)) jinja_env.filters["escape_braces"] = escape_braces runner_template = jinja_env.get_template("groundhog_run.py.jinja") - # Render runner script runner_contents = runner_template.render( pep723_metadata=pep723_metadata, - script_path=user_script_path_remote, + script_path="user_script.py", function_name=function_name, - payload_path=payload_path, - outfile_path=outfile_path, + payload_path="payload.in", + outfile_path="payload.out", module_name=path_to_module_name(script_path), ) - # Read local log level (None if not set) local_log_level = os.getenv("GROUNDHOG_LOG_LEVEL") if local_log_level: local_log_level = local_log_level.upper() @@ -152,16 +136,11 @@ def template_shell_command(script_path: str, function_name: str, payload: str) - uv_config_toml = _serialize_uv_toml(metadata) - # Render shell command shell_template = jinja_env.get_template("shell_command.sh.jinja") shell_command_string = shell_template.render( - user_script_name=user_script_name, user_script_contents=user_script, - runner_name=runner_name, runner_contents=runner_contents, - script_name=script_name, version_spec=version_spec, - payload=payload, log_level=local_log_level, groundhog_timestamp=groundhog_timestamp, env_hash=env_hash, @@ -194,7 +173,3 @@ def _serialize_uv_toml(metadata: Pep723Metadata | None) -> str: def _script_hash_prefix(contents: str, length: int = 8) -> str: return str(sha1(bytes(contents, "utf-8")).hexdigest()[:length]) - - -def _extract_script_basename(script_path: str) -> str: - return Path(script_path).stem diff --git a/tests/conftest.py 
b/tests/conftest.py index bfba954..f95d35a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,12 @@ import os import sys -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from groundhog_hpc.function import Function + @pytest.fixture(scope="session", autouse=True) def configure_rich_for_ci(): @@ -162,7 +164,7 @@ def mock_submission_stack(): Provides access to all mocks and their return values for assertions. Returns dict with: - - script_to_submittable: Mock for script_to_submittable function + - shell_function_prop: PropertyMock for Function.shell_function property - submit_to_executor: Mock for submit_to_executor function - get_endpoint_schema: Mock for get_endpoint_schema function - shell_function: The mock ShellFunction instance @@ -177,9 +179,12 @@ def test_something(mock_submission_stack): mock_shell_func = MagicMock() mock_future = MagicMock() - with patch( - "groundhog_hpc.function.script_to_submittable", return_value=mock_shell_func - ) as mock_script: + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=mock_shell_func, + ) as mock_sf_prop: with patch( "groundhog_hpc.function.submit_to_executor", return_value=mock_future ) as mock_submit: @@ -187,7 +192,7 @@ def test_something(mock_submission_stack): "groundhog_hpc.compute.get_endpoint_schema", return_value={} ) as mock_schema: yield { - "script_to_submittable": mock_script, + "shell_function_prop": mock_sf_prop, "submit_to_executor": mock_submit, "get_endpoint_schema": mock_schema, "shell_function": mock_shell_func, @@ -302,17 +307,18 @@ def test_something(mock_executor): @pytest.fixture def mock_local_result(): - """Create a mock result for local subprocess execution. + """Create mock objects for local subprocess execution tests. Returns a factory function that creates: - - A mock ShellFunction that returns the result - - The mock result object itself + - A mock ShellFunction with a .cmd attribute (for patching Function.shell_function) + - A mock result object (for patching _run_shell_locally return value) Usage: def test_something(mock_local_result): shell_func, result = mock_local_result(stdout='{"result": 42}') - # Use shell_func in patches - # Use result for specific assertions + with patch.object(Function, "shell_function", new_callable=PropertyMock, return_value=shell_func): + with patch("groundhog_hpc.function._run_shell_locally", return_value=result): + ... 
""" def _create( @@ -327,7 +333,8 @@ def _create( result.stderr = stderr result.exception_name = exception_name - shell_func = MagicMock(return_value=result) + shell_func = MagicMock() + shell_func.cmd = "test_cmd {payload}" return shell_func, result return _create diff --git a/tests/test_compute.py b/tests/test_compute.py index 7c86d3a..3512e44 100644 --- a/tests/test_compute.py +++ b/tests/test_compute.py @@ -4,57 +4,73 @@ from unittest.mock import MagicMock, patch from uuid import UUID +import pytest + from groundhog_hpc.compute import ( - script_to_submittable, + _poll_batch_results, + build_shell_function, + submit_batch, submit_to_executor, ) +from groundhog_hpc.future import GroundhogFuture +_ENDPOINT = "12345678-1234-1234-1234-123456789abc" +_FUNCTION_ID = "ffffffff-ffff-ffff-ffff-ffffffffffff" -class TestScriptToSubmittable: - """Test the script_to_submittable function.""" - def test_creates_shell_function(self, tmp_path): - """Test that script_to_submittable creates a ShellFunction.""" - script_path = tmp_path / "test.py" - script_path.write_text("# test") - payload = "__PICKLE__:test_payload" +def _make_shell_function(name="test_func"): + sf = MagicMock() + sf.__name__ = name + return sf - with patch("groundhog_hpc.compute.template_shell_command") as mock_template: - mock_template.return_value = "echo test" - with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_shell_func: - _result = script_to_submittable( - str(script_path), "my_function", payload - ) - # Verify template was called with correct args - mock_template.assert_called_once_with( - str(script_path), "my_function", payload - ) +def _make_batch_client(function_id=_FUNCTION_ID, task_ids=None): + """Mock GC client pre-configured for batch submission.""" + task_ids = task_ids or ["tid-0", "tid-1"] + client = MagicMock() + client.register_function.return_value = function_id + client.create_batch.return_value = MagicMock() + client.batch_run.return_value = {"tasks": {function_id: task_ids}} + return client - # Verify ShellFunction was created with correct args - mock_shell_func.assert_called_once_with( - "echo test", name="my_function", walltime=None - ) - def test_uses_function_name_as_shell_function_name(self, tmp_path): - """Test that function name is used as the ShellFunction name.""" - script_path = tmp_path / "test.py" - script_path.write_text("# test") - payload = "__PICKLE__:test_payload" +def _success(result): + return {"pending": False, "status": "success", "result": result} + + +def _pending(): + return {"pending": True, "status": "unknown"} + + +class TestBuildShellFunction: + """Test the build_shell_function helper.""" + + def test_creates_shell_function_with_correct_name(self): + """Test that dots in function name are replaced with underscores.""" + with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_sf: + build_shell_function("echo test", "my.module.func") + mock_sf.assert_called_once_with( + "echo test", name="my_module_func", walltime=None + ) - with patch("groundhog_hpc.compute.template_shell_command"): - with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_shell_func: - script_to_submittable(str(script_path), "custom_func_name", payload) + def test_passes_walltime(self): + """Test that walltime is forwarded to ShellFunction.""" + with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_sf: + build_shell_function("echo test", "func", walltime=300) + assert mock_sf.call_args[1]["walltime"] == 300 - # Verify name was passed - assert mock_shell_func.call_args[1]["name"] == 
"custom_func_name" + def test_default_walltime_is_none(self): + """Test that walltime defaults to None.""" + with patch("groundhog_hpc.compute.gc.ShellFunction") as mock_sf: + build_shell_function("echo test", "func") + assert mock_sf.call_args[1]["walltime"] is None class TestSubmitToExecutor: """Test the submit_to_executor function.""" def test_creates_executor_and_submits(self, mock_endpoint_uuid, mock_executor): - """Test that Executor is created and submit is called.""" + """Test that Executor is created and submit is called with payload.""" mock_shell_func = MagicMock() mock_future = Future() mock_executor.submit.return_value = mock_future @@ -64,7 +80,10 @@ def test_creates_executor_and_submits(self, mock_endpoint_uuid, mock_executor): with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): result = submit_to_executor( - UUID(mock_endpoint_uuid), user_config, mock_shell_func + UUID(mock_endpoint_uuid), + user_config, + mock_shell_func, + payload="test_payload", ) # Verify Executor was created with correct endpoint and config @@ -74,12 +93,30 @@ def test_creates_executor_and_submits(self, mock_endpoint_uuid, mock_executor): UUID(mock_endpoint_uuid), user_endpoint_config=user_config ) - # Verify submit was called with shell function (payload already baked in) - mock_executor.submit.assert_called_once_with(mock_shell_func) + # Verify submit was called with shell function and payload + mock_executor.submit.assert_called_once_with( + mock_shell_func, payload="test_payload" + ) # Result should be a Future (the deserializing one, not the original) assert isinstance(result, Future) + def test_passes_payload_to_executor_submit(self, mock_endpoint_uuid, mock_executor): + """Test that payload is forwarded to executor.submit as keyword argument.""" + mock_shell_func = MagicMock() + mock_future = Future() + mock_executor.submit.return_value = mock_future + + with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): + with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): + submit_to_executor( + UUID(mock_endpoint_uuid), {}, mock_shell_func, payload="abc123" + ) + + mock_executor.submit.assert_called_once_with( + mock_shell_func, payload="abc123" + ) + def test_returns_deserializing_future(self, mock_endpoint_uuid, mock_executor): """Test that a deserializing future is returned, not the original.""" mock_shell_func = MagicMock() @@ -89,7 +126,7 @@ def test_returns_deserializing_future(self, mock_endpoint_uuid, mock_executor): with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): result = submit_to_executor( - UUID(mock_endpoint_uuid), {}, mock_shell_func + UUID(mock_endpoint_uuid), {}, mock_shell_func, payload="test" ) # Should return a different future than the one from executor.submit @@ -109,7 +146,10 @@ def test_walltime_in_config_passed_to_executor( with patch("groundhog_hpc.compute.gc.Executor", return_value=mock_executor): with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=None): submit_to_executor( - UUID(mock_endpoint_uuid), user_config, mock_shell_func + UUID(mock_endpoint_uuid), + user_config, + mock_shell_func, + payload="test", ) # Verify walltime was NOT extracted from config - it should still be present @@ -119,3 +159,167 @@ def test_walltime_in_config_passed_to_executor( UUID(mock_endpoint_uuid), user_endpoint_config={"account": 
"test", "walltime": 600}, ) + + +class TestSubmitBatch: + def test_returns_one_future_per_payload(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0", "tid-1", "tid-2"]) + mock_globus_client.return_value = client + + futures = submit_batch( + _ENDPOINT, {}, _make_shell_function(), ["p0", "p1", "p2"] + ) + + assert len(futures) == 3 + assert all(isinstance(f, GroundhogFuture) for f in futures) + + def test_each_future_has_task_id_from_batch_run(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0", "tid-1"]) + mock_globus_client.return_value = client + + futures = submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0", "p1"]) + + assert futures[0].task_id == "tid-0" + assert futures[1].task_id == "tid-1" + + def test_register_function_called_once(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0", "tid-1", "tid-2"]) + mock_globus_client.return_value = client + shell_fn = _make_shell_function() + + submit_batch(_ENDPOINT, {}, shell_fn, ["p0", "p1", "p2"]) + + client.register_function.assert_called_once_with(shell_fn) + + def test_batch_add_called_once_per_payload_with_payload_kwarg( + self, mock_globus_client + ): + client = _make_batch_client(task_ids=["tid-0", "tid-1"]) + mock_globus_client.return_value = client + batch_mock = client.create_batch.return_value + + submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0", "p1"]) + + assert batch_mock.add.call_count == 2 + batch_mock.add.assert_any_call(_FUNCTION_ID, kwargs={"payload": "p0"}) + batch_mock.add.assert_any_call(_FUNCTION_ID, kwargs={"payload": "p1"}) + + def test_endpoint_schema_filtering_applied(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0"]) + mock_globus_client.return_value = client + + schema = {"properties": {"account": {"type": "string"}}} + with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=schema): + submit_batch( + _ENDPOINT, + {"account": "proj", "unexpected_key": "val"}, + _make_shell_function(), + ["p0"], + ) + + _, create_batch_kwargs = client.create_batch.call_args + config = create_batch_kwargs["user_endpoint_config"] + assert "account" in config + assert "unexpected_key" not in config + + def test_futures_resolve_via_polling_thread(self, mock_globus_client): + mock_shell_result = MagicMock() + mock_shell_result.returncode = 0 + mock_shell_result.stdout = '"hello"' + mock_shell_result.stderr = "" + + client = _make_batch_client(task_ids=["tid-0"]) + mock_globus_client.return_value = client + + # Resolve the future synchronously by patching _poll_batch_results + def resolve_immediately(task_id_to_future, client, poll_interval=1.0): + task_id_to_future["tid-0"].set_result(mock_shell_result) + + with patch( + "groundhog_hpc.compute._poll_batch_results", side_effect=resolve_immediately + ): + futures = submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0"]) + + assert futures[0].result(timeout=1) == "hello" + + def test_failed_tasks_propagate_exception(self, mock_globus_client): + client = _make_batch_client(task_ids=["tid-0"]) + mock_globus_client.return_value = client + + def fail_immediately(task_id_to_future, client, poll_interval=1.0): + task_id_to_future["tid-0"].set_exception(RuntimeError("task blew up")) + + with patch( + "groundhog_hpc.compute._poll_batch_results", side_effect=fail_immediately + ): + futures = submit_batch(_ENDPOINT, {}, _make_shell_function(), ["p0"]) + + with pytest.raises(RuntimeError, match="task blew up"): + futures[0].result(timeout=1) + + +class 
TestPollBatchResults: + def test_resolves_successful_task(self): + mock_shell_result = MagicMock() + mock_shell_result.returncode = 0 + mock_shell_result.stdout = '"done"' + + fut = Future() + client = MagicMock() + client.get_batch_result.return_value = {"tid-0": _success(mock_shell_result)} + + _poll_batch_results({"tid-0": fut}, client, poll_interval=0) + + assert fut.done() + assert fut.result() is mock_shell_result + + def test_failed_task_sets_exception(self): + mock_exc = MagicMock() + mock_exc.reraise.side_effect = ValueError("remote error") + + fut = Future() + client = MagicMock() + client.get_batch_result.return_value = { + "tid-0": {"pending": False, "status": "failed", "exception": mock_exc} + } + + _poll_batch_results({"tid-0": fut}, client, poll_interval=0) + + assert fut.done() + with pytest.raises(ValueError, match="remote error"): + fut.result() + + def test_pending_task_stays_unresolved_until_next_poll(self): + mock_shell_result = MagicMock() + mock_shell_result.returncode = 0 + mock_shell_result.stdout = '"done"' + + fut = Future() + client = MagicMock() + client.get_batch_result.side_effect = [ + {"tid-0": _pending()}, + {"tid-0": _success(mock_shell_result)}, + ] + + _poll_batch_results({"tid-0": fut}, client, poll_interval=0) + + assert client.get_batch_result.call_count == 2 + assert fut.done() + + def test_polls_only_remaining_pending_tasks(self): + r0, r1 = MagicMock(), MagicMock() + r0.returncode = r1.returncode = 0 + r0.stdout = r1.stdout = '"ok"' + + fut0, fut1 = Future(), Future() + client = MagicMock() + client.get_batch_result.side_effect = [ + {"tid-0": _success(r0), "tid-1": _pending()}, + {"tid-1": _success(r1)}, + ] + + _poll_batch_results({"tid-0": fut0, "tid-1": fut1}, client, poll_interval=0) + + second_call_ids = client.get_batch_result.call_args_list[1][0][0] + assert second_call_ids == ["tid-1"] + assert fut0.done() and fut1.done() diff --git a/tests/test_function.py b/tests/test_function.py index 76dab4e..2ca1084 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,12 +1,14 @@ """Tests for the Function class.""" import os -from unittest.mock import MagicMock, patch +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock, PropertyMock, patch import pytest from groundhog_hpc.errors import ModuleImportError from groundhog_hpc.function import Function +from groundhog_hpc.future import GroundhogFuture from tests.test_fixtures import simple_function # Alias for backward compatibility with existing tests @@ -105,12 +107,11 @@ def test_script_path_raises_when_uninspectable(self): ): _ = func.script_path - def test_submit_creates_shell_function(self, tmp_path, mock_endpoint_uuid): - """Test that submit creates a shell function using script_to_submittable.""" + def test_submit_uses_shell_function_property(self, tmp_path, mock_endpoint_uuid): + """Test that submit uses the cached shell_function property.""" script_path = tmp_path / "test_script.py" - script_content = "# test script content" - script_path.write_text(script_content) + script_path.write_text("# test script content") func = Function(dummy_function, endpoint=mock_endpoint_uuid) func._script_path = str(script_path) @@ -118,26 +119,24 @@ def test_submit_creates_shell_function(self, tmp_path, mock_endpoint_uuid): mock_shell_func = MagicMock() mock_future = MagicMock() - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, - ) as 
mock_script_to_submittable: + ): with patch( "groundhog_hpc.function.submit_to_executor", return_value=mock_future, - ): + ) as mock_submit: with patch( "groundhog_hpc.compute.get_endpoint_schema", return_value={} ): func.submit() - # Verify script_to_submittable was called with correct arguments - mock_script_to_submittable.assert_called_once() - call_args = mock_script_to_submittable.call_args[0] - assert call_args[0] == str(script_path) - assert ( - call_args[1] == "simple_function" - ) # dummy_function is an alias to simple_function + # Verify submit_to_executor was called with the cached shell_function + mock_submit.assert_called_once() + assert mock_submit.call_args[1]["shell_function"] is mock_shell_func class TestSubmitMethod: @@ -188,9 +187,9 @@ def test_submit_serializes_arguments( call_args = mock_serialize.call_args[0][0] assert call_args == ((1, 2), {"kwarg1": "value1"}) - # Verify script_to_submittable received the serialized payload + # Verify submit_to_executor received the serialized payload assert ( - mock_submission_stack["script_to_submittable"].call_args[0][2] + mock_submission_stack["submit_to_executor"].call_args[1]["payload"] == "serialized_payload" ) @@ -265,21 +264,6 @@ def test_callsite_walltime_goes_to_config( config = mock_submit.call_args[1]["user_endpoint_config"] assert config["walltime"] == 120 - def test_function_walltime_sets_shellfunction_walltime( - self, function_with_script, mock_submission_stack - ): - """Test that Function.walltime attribute sets ShellFunction walltime (escape hatch).""" - # Create function and manually set walltime (escape hatch) - func = function_with_script() - func.walltime = 120 - - func.submit() - - # Verify script_to_submittable was called with walltime parameter - mock_script_to_submittable = mock_submission_stack["script_to_submittable"] - call_args = mock_script_to_submittable.call_args - assert call_args[1]["walltime"] == 120 - def test_callsite_user_config_overrides_default( self, function_with_script, mock_submission_stack ): @@ -363,6 +347,46 @@ def test_default_worker_init_preserved_when_no_callsite_override( assert "worker_init" in config assert default_worker_init in config["worker_init"] + def test_executor_kwargs_forwarded_to_submit_to_executor( + self, function_with_script, mock_submission_stack + ): + """Test that executor_kwargs are forwarded to submit_to_executor.""" + func = function_with_script() + + func.submit(executor_kwargs={"amqp_port": 5671}) + + mock_submit = mock_submission_stack["submit_to_executor"] + assert mock_submit.call_args[1]["executor_kwargs"] == {"amqp_port": 5671} + + def test_executor_kwargs_defaults_to_none( + self, function_with_script, mock_submission_stack + ): + """Test that executor_kwargs defaults to None when not provided.""" + func = function_with_script() + + func.submit() + + mock_submit = mock_submission_stack["submit_to_executor"] + assert mock_submit.call_args[1]["executor_kwargs"] is None + + def test_executor_kwargs_does_not_bleed_into_user_endpoint_config( + self, function_with_script, mock_submission_stack + ): + """Test that executor_kwargs keys are not added to user_endpoint_config.""" + func = function_with_script() + + mock_schema = {"properties": {"account": {"type": "string"}}} + mock_submission_stack["get_endpoint_schema"].return_value = mock_schema + + func.submit( + executor_kwargs={"amqp_port": 5671}, + user_endpoint_config={"account": "x"}, + ) + + mock_submit = mock_submission_stack["submit_to_executor"] + config = 
mock_submit.call_args[1]["user_endpoint_config"] + assert "amqp_port" not in config + class TestLocalMethod: """Test the local() method for running functions in local subprocess.""" @@ -370,16 +394,9 @@ class TestLocalMethod: def test_local_executes_function_and_returns_result( self, tmp_path, mock_local_result ): - """Test that local() executes the function via ShellFunction and returns result.""" - # Create a test script + """Test that local() executes the function and returns deserialized result.""" script_path = tmp_path / "test_local.py" - script_content = """import groundhog_hpc as hog - -@hog.function() -def add(a, b): - return a + b -""" - script_path.write_text(script_content) + script_path.write_text("# test") def add(a, b): return a + b @@ -387,17 +404,21 @@ def add(a, b): func = Function(add) func._script_path = str(script_path) - # Create mock result - shell_func, result = mock_local_result(stdout='{"result": 5}') + shell_func, run_result = mock_local_result(stdout='{"result": 5}') - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", return_value=(None, 5) - ) as mock_deserialize: - result_value = func.local(2, 3) + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.deserialize_stdout", return_value=(None, 5) + ) as mock_deserialize: + result_value = func.local(2, 3) assert result_value == 5 mock_deserialize.assert_called_once_with('{"result": 5}') @@ -410,74 +431,181 @@ def test_local_serializes_arguments(self, tmp_path, mock_local_result): func = Function(dummy_function) func._script_path = str(script_path) - shell_func, result = mock_local_result(stdout='{"result": "success"}') + shell_func, run_result = mock_local_result(stdout='{"result": "success"}') with patch( "groundhog_hpc.function.serialize", return_value="serialized" ) as mock_serialize: - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", - return_value=(None, "success"), + "groundhog_hpc.function._run_shell_locally", return_value=run_result ): - func.local(1, 2, key="value") + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "success"), + ): + func.local(1, 2, key="value") - # Verify serialize was called with args, kwargs, and proxy_threshold_mb=1.0 mock_serialize.assert_called_once() call_args = mock_serialize.call_args[0][0] call_kwargs = mock_serialize.call_args[1] assert call_args == ((1, 2), {"key": "value"}) assert call_kwargs.get("proxy_threshold_mb") == 1.0 - def test_local_runs_in_temporary_directory(self, tmp_path): - """Test that local() sets GC_TASK_SANDBOX_DIR to a temporary directory.""" + def test_gc_task_sandbox_dir_not_set_on_parent_process( + self, tmp_path, mock_local_result + ): + """local() must not mutate os.environ with GC_TASK_SANDBOX_DIR.""" script_path = tmp_path / "test_local.py" script_path.write_text("# test") func = Function(dummy_function) func._script_path = str(script_path) - mock_result = MagicMock() - mock_result.returncode = 0 - mock_result.stdout = "result" - mock_result.stderr = "" - mock_result.exception_name = None + shell_func, run_result = mock_local_result() + + original = os.environ.pop("GC_TASK_SANDBOX_DIR", None) + try: + 
with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=shell_func, + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "result"), + ): + func.local() - mock_shell_function = MagicMock(return_value=mock_result) + assert "GC_TASK_SANDBOX_DIR" not in os.environ + finally: + if original is not None: + os.environ["GC_TASK_SANDBOX_DIR"] = original - # Store original env var if it exists - original_sandbox_dir = os.environ.get("GC_TASK_SANDBOX_DIR") + def test_gc_task_sandbox_dir_not_overwritten_if_already_set( + self, tmp_path, mock_local_result + ): + """local() must not overwrite an externally set GC_TASK_SANDBOX_DIR.""" + script_path = tmp_path / "test_local.py" + script_path.write_text("# test") + func = Function(dummy_function) + func._script_path = str(script_path) + + shell_func, run_result = mock_local_result() + + os.environ["GC_TASK_SANDBOX_DIR"] = "/my/custom/dir" try: - # Clear it for this test - if "GC_TASK_SANDBOX_DIR" in os.environ: - del os.environ["GC_TASK_SANDBOX_DIR"] + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=shell_func, + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "result"), + ): + func.local() + + assert os.environ["GC_TASK_SANDBOX_DIR"] == "/my/custom/dir" + finally: + del os.environ["GC_TASK_SANDBOX_DIR"] + + def test_two_concurrent_local_calls_dont_interfere( + self, tmp_path, mock_local_result + ): + """Concurrent local() calls must not share GC_TASK_SANDBOX_DIR via os.environ.""" + import threading + + script_path = tmp_path / "test_local.py" + script_path.write_text("# test") + + func = Function(dummy_function) + func._script_path = str(script_path) + seen_dirs: list[str] = [] + + def capture_env(cmd_template, payload, tmpdir): + seen_dirs.append(os.environ.get("GC_TASK_SANDBOX_DIR", "NOT_SET")) + mock = MagicMock() + mock.returncode = 0 + mock.stdout = '"ok"' + mock.stderr = "" + mock.exception_name = None + return mock + + shell_func, _ = mock_local_result() + + os.environ.pop("GC_TASK_SANDBOX_DIR", None) + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=shell_func, + ): with patch( - "groundhog_hpc.function.script_to_submittable", - return_value=mock_shell_function, + "groundhog_hpc.function._run_shell_locally", side_effect=capture_env ): with patch( "groundhog_hpc.function.deserialize_stdout", - return_value=(None, "result"), + return_value=(None, "ok"), ): - func.local() + threads = [threading.Thread(target=func.local) for _ in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() - # Verify GC_TASK_SANDBOX_DIR was set - assert "GC_TASK_SANDBOX_DIR" in os.environ - sandbox_dir = os.environ["GC_TASK_SANDBOX_DIR"] - assert isinstance(sandbox_dir, str) - assert len(sandbox_dir) > 0 + # Neither thread should have seen GC_TASK_SANDBOX_DIR in os.environ + assert all(d == "NOT_SET" for d in seen_dirs) + assert "GC_TASK_SANDBOX_DIR" not in os.environ - finally: - # Restore original state - if original_sandbox_dir is not None: - os.environ["GC_TASK_SANDBOX_DIR"] = original_sandbox_dir - elif "GC_TASK_SANDBOX_DIR" in os.environ: - del os.environ["GC_TASK_SANDBOX_DIR"] + def test_local_passes_tmpdir_to_run_shell_locally( + self, tmp_path, mock_local_result + ): 
+ """local() passes a real tmpdir path to _run_shell_locally.""" + script_path = tmp_path / "test_local.py" + script_path.write_text("# test") + + func = Function(dummy_function) + func._script_path = str(script_path) + + shell_func, run_result = mock_local_result() + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=shell_func, + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ) as mock_run: + with patch("groundhog_hpc.function.serialize", return_value="PAYLOAD"): + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, "result"), + ): + func.local() + + # Third argument is tmpdir; second is the serialized payload + _, call_payload, call_tmpdir = mock_run.call_args[0] + assert call_payload == "PAYLOAD" + assert isinstance(call_tmpdir, str) and len(call_tmpdir) > 0 def test_local_raises_if_script_path_unavailable(self): """Test that local() raises ValueError if script path cannot be determined.""" @@ -488,7 +616,6 @@ def local_func(): func = Function(local_func) func._script_path = None - # Mock inspect.getfile to raise TypeError (e.g., for built-in functions) with patch( "groundhog_hpc.function.inspect.getfile", side_effect=TypeError("not a file"), @@ -496,104 +623,182 @@ def local_func(): with pytest.raises(ValueError, match="Could not determine script path"): func.local() - def test_local_uses_script_to_submittable(self, tmp_path, mock_local_result): - """Test that local() uses script_to_submittable to create ShellFunction.""" + def test_local_uses_shell_function_property(self, tmp_path, mock_local_result): + """local() accesses the cached shell_function property for .cmd.""" script_path = tmp_path / "test_local.py" script_path.write_text("# test") func = Function(dummy_function) func._script_path = str(script_path) - # Set the import flag to allow .local() call - import sys - - test_module = sys.modules.get("tests.test_fixtures") - test_module.__groundhog_imported__ = True + shell_func, run_result = mock_local_result(stdout="result") - shell_func, result = mock_local_result(stdout="result") - - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, - ) as mock_script_to_submittable: + ) as mock_sf_prop: with patch( - "groundhog_hpc.function.deserialize_stdout", - return_value=(None, "result"), + "groundhog_hpc.function._run_shell_locally", return_value=run_result ): - func.local() - - # Verify script_to_submittable was called with script path, function name, and payload - assert mock_script_to_submittable.call_count == 1 - call_args = mock_script_to_submittable.call_args[0] - assert call_args[0] == str(script_path) - assert call_args[1] == "simple_function" - assert len(call_args) == 3 # script_path, function_name, payload - - def test_local_calls_shell_function(self, tmp_path, mock_local_result): - """Test that local() calls the ShellFunction returned by script_to_submittable.""" - script_path = tmp_path / "test_local.py" - script_path.write_text("# test") - - func = Function(dummy_function) - func._script_path = str(script_path) - - shell_func, result = mock_local_result(stdout="result") - - with patch( - "groundhog_hpc.function.script_to_submittable", - return_value=shell_func, - ): - with patch("groundhog_hpc.function.serialize", return_value="ABC123"): with patch( "groundhog_hpc.function.deserialize_stdout", return_value=(None, "result"), ): func.local() - 
# Verify ShellFunction was called (invoked via __call__) - shell_func.assert_called_once() - # Verify it was called with no arguments (ShellFunction handles its own execution) - assert shell_func.call_args[0] == () + mock_sf_prop.assert_called() - def test_local_infers_script_path_from_function(self, tmp_path): + def test_local_infers_script_path_from_function(self, tmp_path, mock_local_result): """Test that local() can infer script path from function's source file.""" - # Create a test script script_path = tmp_path / "inferred_script.py" - script_content = """def my_function(): - return 42 -""" - script_path.write_text(script_content) + script_path.write_text("def my_function():\n return 42\n") def my_function(): return 42 func = Function(my_function) - func._script_path = None # Force it to infer - - mock_result = MagicMock() - mock_result.returncode = 0 - mock_result.stdout = "42" - mock_result.stderr = "" - mock_result.exception_name = None + func._script_path = None - mock_shell_function = MagicMock(return_value=mock_result) + shell_func, run_result = mock_local_result(stdout="42") + run_result.returncode = 0 - # Mock inspect.getfile to return our test script with patch( "groundhog_hpc.function.inspect.getfile", return_value=str(script_path) ): - with patch( - "groundhog_hpc.function.script_to_submittable", - return_value=mock_shell_function, + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=shell_func, ): with patch( - "groundhog_hpc.function.deserialize_stdout", return_value=(None, 42) + "groundhog_hpc.function._run_shell_locally", return_value=run_result ): - result = func.local() + with patch( + "groundhog_hpc.function.deserialize_stdout", + return_value=(None, 42), + ): + result = func.local() assert result == 42 +class TestShellCommandProperty: + """Test the shell_command lazy-cached property.""" + + def test_calls_template_with_script_path_and_name(self, tmp_path): + """shell_command calls template_shell_command with correct args.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + with patch( + "groundhog_hpc.function.template_shell_command", + return_value="parameterized_cmd", + ) as mock_template: + result = func.shell_command + + mock_template.assert_called_once_with(func._script_path, func.name) + assert result == "parameterized_cmd" + + def test_caches_result_on_second_access(self, tmp_path): + """shell_command returns cached value without re-calling the template.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + with patch( + "groundhog_hpc.function.template_shell_command", + return_value="cmd1", + ) as mock_template: + first = func.shell_command + second = func.shell_command + + mock_template.assert_called_once() + assert first == second == "cmd1" + + +class TestShellFunctionProperty: + """Test the shell_function lazy-cached property.""" + + def test_calls_build_shell_function_with_correct_args(self, tmp_path): + """shell_function calls build_shell_function with shell_command, name, walltime.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + func.walltime = 120 + + mock_sf = MagicMock() + + with patch( + "groundhog_hpc.function.template_shell_command", + return_value="paramcmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=mock_sf, + ) as mock_build: + result = func.shell_function + + mock_build.assert_called_once_with("paramcmd", func.name, walltime=120) + assert result is 
mock_sf + + def test_caches_result_on_second_access(self, tmp_path): + """shell_function returns cached value without re-calling build_shell_function.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + mock_sf = MagicMock() + + with patch( + "groundhog_hpc.function.template_shell_command", + return_value="cmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=mock_sf, + ) as mock_build: + first = func.shell_function + second = func.shell_function + + mock_build.assert_called_once() + assert first is second is mock_sf + + def test_default_walltime_is_none(self, tmp_path): + """shell_function passes walltime=None when not set.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + + with patch( + "groundhog_hpc.function.template_shell_command", + return_value="cmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=MagicMock(), + ) as mock_build: + func.shell_function + + assert mock_build.call_args[1]["walltime"] is None + + def test_walltime_flows_into_shell_function(self, tmp_path): + """walltime set before first access is used by build_shell_function.""" + func = Function(dummy_function) + func._script_path = str(tmp_path / "fake.py") + func.walltime = 300 + + with patch( + "groundhog_hpc.function.template_shell_command", + return_value="cmd", + ): + with patch( + "groundhog_hpc.function.build_shell_function", + return_value=MagicMock(), + ) as mock_build: + func.shell_function + + assert mock_build.call_args[1]["walltime"] == 300 + + class TestLocalAlwaysUsesSubprocess: """Test that .local() always uses subprocess (no direct call fallback).""" @@ -624,19 +829,345 @@ def test_func(x): test_module = sys.modules[func._wrapped_function.__module__] test_module.__groundhog_imported__ = True - shell_func, result = mock_local_result(stdout="84") + shell_func, run_result = mock_local_result(stdout="84") - # Mock script_to_submittable to verify subprocess is used - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=shell_func, - ) as mock_script_to_submittable: + ): with patch( - "groundhog_hpc.function.deserialize_stdout", return_value=(None, 84) - ): - result_value = func.local(42) + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ) as mock_run: + with patch( + "groundhog_hpc.function.deserialize_stdout", return_value=(None, 84) + ): + result_value = func.local(42) - # Should always use subprocess (ShellFunction) + # Always uses _run_shell_locally (never calls the function directly) assert result_value == 84 - mock_script_to_submittable.assert_called_once() - shell_func.assert_called_once() + mock_run.assert_called_once() + + +class TestBatchSubmit: + """Tests for Function.batch_submit().""" + + def _make_func(self, tmp_path, mock_endpoint_uuid): + script_path = tmp_path / "test_script.py" + script_path.write_text("# test") + func = Function(dummy_function, endpoint=mock_endpoint_uuid) + func._script_path = str(script_path) + return func + + def _mock_submit_batch(self, n=3): + """Return a mock submit_batch that produces n GroundhogFutures.""" + from concurrent.futures import Future as CF + + futures = [] + for i in range(n): + cf = CF() + cf.set_result(MagicMock(returncode=0, stdout=f'"{i}"', stderr="")) + gf = MagicMock() + gf._task_id = f"tid-{i}" + futures.append(gf) + return MagicMock(return_value=futures), futures + + def 
test_raises_without_import_flag(self, tmp_path, mock_endpoint_uuid): + import sys + + func = self._make_func(tmp_path, mock_endpoint_uuid) + test_module = sys.modules.get("tests.test_fixtures") + had_flag = hasattr(test_module, "__groundhog_imported__") + if had_flag: + del test_module.__groundhog_imported__ + try: + with pytest.raises(ModuleImportError): + func.batch_submit(args=[(1,)]) + finally: + if had_flag: + test_module.__groundhog_imported__ = True + + def test_raises_when_args_and_kwargs_both_empty(self, tmp_path, mock_endpoint_uuid): + func = self._make_func(tmp_path, mock_endpoint_uuid) + with pytest.raises(ValueError, match="both empty"): + func.batch_submit() + + def test_returns_one_future_per_task( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=3) + with patch("groundhog_hpc.function.submit_batch", mock_batch): + result = func.batch_submit(args=[(1,), (2,), (3,)]) + assert len(result) == 3 + + def test_args_and_kwargs_zipped_with_fill( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=2) + captured = [] + + def fake_serialize(data, **kw): + captured.append(data) + return f"payload_{len(captured)}" + + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", side_effect=fake_serialize): + func.batch_submit(args=[(1,), (2,)], kwargs=[{"k": "v"}]) + + assert captured[0] == ((1,), {"k": "v"}) + assert captured[1] == ((2,), {}) + + def test_kwargs_only_batch_uses_empty_args_tuple( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=2) + captured = [] + + def fake_serialize(data, **kw): + captured.append(data) + return "p" + + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", side_effect=fake_serialize): + func.batch_submit(kwargs=[{"x": 1}, {"x": 2}]) + + assert captured[0] == ((), {"x": 1}) + assert captured[1] == ((), {"x": 2}) + + def test_uses_resolved_endpoint( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=1) + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", return_value="p"): + func.batch_submit(args=[(1,)]) + endpoint_arg = mock_batch.call_args[0][0] + from uuid import UUID + + assert endpoint_arg == UUID(mock_endpoint_uuid) + + def test_callsite_endpoint_overrides_decorator( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + other_uuid = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, futures = self._mock_submit_batch(n=1) + with patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", return_value="p"): + func.batch_submit(args=[(1,)], endpoint=other_uuid) + from uuid import UUID + + assert mock_batch.call_args[0][0] == UUID(other_uuid) + + def test_function_name_and_config_set_on_each_future( + self, tmp_path, mock_endpoint_uuid, mock_submission_stack + ): + func = self._make_func(tmp_path, mock_endpoint_uuid) + mock_batch, mock_futures = self._mock_submit_batch(n=3) + with 
patch("groundhog_hpc.function.submit_batch", mock_batch): + with patch("groundhog_hpc.function.serialize", return_value="p"): + result = func.batch_submit(args=[(1,), (2,), (3,)]) + for f in result: + assert f.function_name == func.name + assert f.user_endpoint_config is not None + + +class TestBatchLocal: + """Tests for Function.batch_local().""" + + def _make_func(self, tmp_path): + script_path = tmp_path / "test_script.py" + script_path.write_text("# test") + func = Function(dummy_function) + func._script_path = str(script_path) + return func + + def _mock_shell_func(self): + sf = MagicMock() + sf.cmd = "test_cmd {payload}" + return sf + + def _make_run_result(self, stdout='"ok"'): + r = MagicMock() + r.returncode = 0 + r.stdout = stdout + r.stderr = "" + r.exception_name = None + return r + + def test_raises_without_import_flag(self, tmp_path): + import sys + + func = self._make_func(tmp_path) + test_module = sys.modules.get("tests.test_fixtures") + had_flag = hasattr(test_module, "__groundhog_imported__") + if had_flag: + del test_module.__groundhog_imported__ + try: + with pytest.raises(ModuleImportError): + func.batch_local(args=[(1,)]) + finally: + if had_flag: + test_module.__groundhog_imported__ = True + + def test_raises_when_args_and_kwargs_both_empty(self, tmp_path): + func = self._make_func(tmp_path) + with pytest.raises(ValueError, match="both empty"): + func.batch_local() + + def test_returns_one_future_per_task(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = func.batch_local(args=[(1,), (2,)]) + assert len(futures) == 2 + assert all(isinstance(f, GroundhogFuture) for f in futures) + + def test_returns_immediately_without_blocking(self, tmp_path): + import threading + + func = self._make_func(tmp_path) + started = threading.Event() + finished = threading.Event() + + def slow_run(cmd_template, payload, tmpdir): + started.set() + finished.wait(timeout=2) + return self._make_run_result() + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", side_effect=slow_run + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = func.batch_local(args=[(1,)]) + + # batch_local returned before the worker finished + assert len(futures) == 1 + assert not futures[0].done() + finished.set() + + def test_args_and_kwargs_zipped_with_fill(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + captured = [] + + def fake_serialize(data, **kw): + captured.append(data) + return "p" + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.serialize", side_effect=fake_serialize + ): + func.batch_local(args=[(1,), (2,)], kwargs=[{"k": "v"}]) + + assert captured[0] == ((1,), {"k": "v"}) + assert captured[1] == ((2,), {}) + + def test_executor_kwargs_passed_to_thread_pool(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with 
patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + with patch( + "groundhog_hpc.function.ThreadPoolExecutor", + wraps=ThreadPoolExecutor, + ) as mock_tpe: + func.batch_local( + args=[(1,)], executor_kwargs={"max_workers": 2} + ) + mock_tpe.assert_called_once_with(max_workers=2) + + def test_gc_task_sandbox_dir_not_set_on_parent_process(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + os.environ.pop("GC_TASK_SANDBOX_DIR", None) + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = func.batch_local(args=[(1,)]) + # Wait for workers to finish + for f in futures: + try: + f.result(timeout=2) + except Exception: + pass + assert "GC_TASK_SANDBOX_DIR" not in os.environ + + def test_task_id_is_none_for_all_local_futures(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch("groundhog_hpc.function.serialize", return_value="p"): + futures = func.batch_local(args=[(1,), (2,)]) + for f in futures: + assert f.task_id is None + + def test_serialize_called_once_per_task(self, tmp_path): + func = self._make_func(tmp_path) + run_result = self._make_run_result() + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=self._mock_shell_func(), + ): + with patch( + "groundhog_hpc.function._run_shell_locally", return_value=run_result + ): + with patch( + "groundhog_hpc.function.serialize", return_value="p" + ) as mock_ser: + func.batch_local(args=[(1,), (2,), (3,)]) + assert mock_ser.call_count == 3 diff --git a/tests/test_future.py b/tests/test_future.py index db1d884..693b42a 100644 --- a/tests/test_future.py +++ b/tests/test_future.py @@ -88,27 +88,29 @@ def test_handles_shell_execution_errors(self): assert exc_info.value.returncode == 1 assert "something went wrong" in exc_info.value.stderr - def test_preserves_task_id(self): - """Test that task_id attribute is preserved on the deserializing future.""" + def test_task_id_falls_through_to_original_future(self): + """task_id reads from original future when _task_id is None.""" original = Future() - original.task_id = "test-task-123" + original.task_id = "abc-123" deserializing = GroundhogFuture(original) - # Create a successful result - mock_shell_result = MagicMock() - mock_shell_result.returncode = 0 - mock_shell_result.stdout = '"test"' + assert deserializing.task_id == "abc-123" - original.set_result(mock_shell_result) + def test_explicit_task_id_takes_precedence(self): + """_task_id takes precedence over the original future's task_id attribute.""" + original = Future() + original.task_id = "from-future" + deserializing = GroundhogFuture(original) + deserializing._task_id = "explicit" - # Wait for callback - import time + assert deserializing.task_id == "explicit" - time.sleep(0.01) + def test_task_id_returns_none_when_neither_source_has_it(self): 
+ """task_id returns None without raising when the underlying future has no task_id.""" + original = Future() # plain Future, no task_id attribute + deserializing = GroundhogFuture(original) - # Task ID should be preserved - assert hasattr(deserializing, "task_id") - assert deserializing.task_id == "test-task-123" + assert deserializing.task_id is None def test_shell_result_property_returns_raw_result(self): """Test that shell_result property provides access to raw ShellResult.""" diff --git a/tests/test_mark_import_safe.py b/tests/test_mark_import_safe.py index 48b94a6..5ca56f1 100644 --- a/tests/test_mark_import_safe.py +++ b/tests/test_mark_import_safe.py @@ -2,7 +2,7 @@ import sys import types -from unittest.mock import Mock, patch +from unittest.mock import Mock, PropertyMock, patch import pytest @@ -174,20 +174,30 @@ def my_func(): # Verify flag is set assert module.__groundhog_imported__ is True - # Mock script_to_submittable to avoid actual subprocess execution + # Mock _run_shell_locally to avoid actual subprocess execution mock_shell_func = Mock() - mock_result = Mock() - mock_result.returncode = 0 - mock_result.stdout = 'hello\n__GROUNDHOG_RESULT__\n"hello"' - mock_result.stderr = "" - mock_shell_func.return_value = mock_result - - with patch( - "groundhog_hpc.function.script_to_submittable", return_value=mock_shell_func + mock_shell_func.cmd = "test {payload}" + mock_run_result = Mock() + mock_run_result.returncode = 0 + mock_run_result.stdout = 'hello\n__GROUNDHOG_RESULT__\n"hello"' + mock_run_result.stderr = "" + mock_run_result.exception_name = None + + from groundhog_hpc.function import Function + + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, + return_value=mock_shell_func, ): - # Now .local() should work (won't raise ModuleImportError) - result = module.my_func.local() - assert result == "hello" + with patch( + "groundhog_hpc.function._run_shell_locally", + return_value=mock_run_result, + ): + # Now .local() should work (won't raise ModuleImportError) + result = module.my_func.local() + assert result == "hello" # Cleanup del sys.modules["test_module5"] diff --git a/tests/test_method.py b/tests/test_method.py index 28b321b..ac9c33d 100644 --- a/tests/test_method.py +++ b/tests/test_method.py @@ -1,6 +1,6 @@ """Tests for the Method class.""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from groundhog_hpc.function import Function, Method @@ -93,19 +93,22 @@ def compute(x): mock_shell_func = MagicMock() mock_future = MagicMock() - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, - ) as mock_script_to_submittable: + ): with patch( "groundhog_hpc.function.submit_to_executor", return_value=mock_future, - ): + ) as mock_submit: with patch( "groundhog_hpc.compute.get_endpoint_schema", return_value={} ): method.submit(5) - # Verify qualname was passed correctly - call_args = mock_script_to_submittable.call_args[0] - assert call_args[1] == "MyClass.compute" + # Verify the function name (qualname) is used — visible in the shell_function property name + assert method.name == "MyClass.compute" + # Verify submit was called (method uses the same submit path as Function) + mock_submit.assert_called_once() diff --git a/tests/test_pep723_integration.py b/tests/test_pep723_integration.py index 05dd381..3807ecf 100644 --- a/tests/test_pep723_integration.py +++ 
b/tests/test_pep723_integration.py @@ -7,7 +7,7 @@ import os import tempfile from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch from uuid import UUID import pytest @@ -69,8 +69,10 @@ def test_func(): "qos": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -152,8 +154,10 @@ def test_func(): "partition": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -232,8 +236,10 @@ def test_func(): "cores": {"type": "integer"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -304,8 +310,10 @@ def test_func(): "qos": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -372,8 +380,10 @@ def test_func(): # Mock schema that includes worker_init mock_schema = {"properties": {"worker_init": {"type": "string"}}} - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -455,8 +465,10 @@ def test_func(): } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -482,8 +494,10 @@ def test_func(): # Reset mock mock_submit.reset_mock() - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( @@ -561,8 +575,10 @@ def test_func(): "qos": {"type": "string"}, } } - with patch( - "groundhog_hpc.function.script_to_submittable", + with patch.object( + Function, + "shell_function", + new_callable=PropertyMock, return_value=mock_shell_func, ): with patch( diff --git a/tests/test_templating.py b/tests/test_templating.py index ac27f7f..aeccb66 100644 --- a/tests/test_templating.py +++ b/tests/test_templating.py @@ -28,7 +28,7 @@ def foo(): script_path.write_text(script_content) # Should not raise any errors - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") assert isinstance(shell_command, str) # User script should be included as-is (with __main__ block) assert 'if __name__ == "__main__":' in shell_command @@ -49,13 +49,14 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") - # Should create both user script and runner - assert "_runner.py" in shell_command + # Should create runner in TASK_DIR + assert "$TASK_DIR/runner.py" in shell_command # Runner should import the user script assert ( - 'module = import_user_script("test_script", "test_script-' in shell_command + 'module = import_user_script("test_script", "user_script.py")' + in shell_command ) # Runner 
should invoke the target function using attrgetter assert 'func = attrgetter("foo")(module)' in shell_command @@ -76,7 +77,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Runner should contain the metadata assert 'requires-python = ">=3.12"' in shell_command @@ -105,7 +106,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Runner should contain the [tool.uv] section assert "[tool.uv]" in shell_command @@ -130,7 +131,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Should NOT contain --managed-python (it's now in [tool.uv]) assert "--managed-python" not in shell_command @@ -154,7 +155,7 @@ def foo(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "foo", "test_payload") + shell_command = template_shell_command(str(script_path), "foo") # Check that it's a non-empty string assert isinstance(shell_command, str) @@ -175,9 +176,7 @@ def test_func(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "test_func", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "test_func") # Should include the basename assert "my_script" in shell_command @@ -197,9 +196,7 @@ def my_function(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "my_function", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "my_function") assert "my_function" in shell_command @@ -218,11 +215,10 @@ def func(): """ script_path.write_text(script_content) - test_payload = "MY_TEST_PAYLOAD_12345" - shell_command = template_shell_command(str(script_path), "func", test_payload) + shell_command = template_shell_command(str(script_path), "func") - # Payload should be rendered directly in the command (via Jinja2) - assert test_payload in shell_command + # Command should contain the {payload} placeholder (filled in at call time) + assert "{payload}" in shell_command def test_includes_uv_commands(self, tmp_path): """Test that the shell command uses uv for env creation.""" @@ -239,7 +235,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "test_payload") + shell_command = template_shell_command(str(script_path), "func") # Check for uv installation assert "uv.find_uv_bin()" in shell_command @@ -262,9 +258,7 @@ def dict_func(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "dict_func", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "dict_func") # Curly braces in user code should be doubled (escaped via Jinja2 filter) # This is needed because Globus Compute's ShellFunction calls .format() @@ -293,20 +287,18 @@ def use_torch(): """ script_path.write_text(script_content) - shell_command = template_shell_command( - str(script_path), "use_torch", "test_payload" - ) + shell_command = template_shell_command(str(script_path), "use_torch") # Simulate what Globus Compute's ShellFunction does: - # It calls 
.format() on the command (without any kwargs) + # It calls .format(payload=...) on the command try: # This should not raise KeyError if curly braces are properly escaped - formatted = shell_command.format() + formatted = shell_command.format(payload="test_payload") # After .format(), the doubled braces should become single braces assert '{"torch"' in formatted except KeyError as e: pytest.fail( - f"shell_command.format() raised KeyError: {e}. " + f"shell_command.format(payload=...) raised KeyError: {e}. " "This means curly braces in user code are not properly escaped!" ) @@ -338,8 +330,8 @@ def func2(): script1_path.write_text(script1_content) script2_path.write_text(script2_content) - command1 = template_shell_command(str(script1_path), "func1", "test_payload") - command2 = template_shell_command(str(script2_path), "func2", "test_payload") + command1 = template_shell_command(str(script1_path), "func1") + command2 = template_shell_command(str(script2_path), "func2") # Extract the script names (format: basename-hash) # They should have different hashes since content differs @@ -367,7 +359,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "test_payload") + shell_command = template_shell_command(str(script_path), "func") # Should include the package-specific exclude-newer override assert "--exclude-newer-package groundhog-hpc=" in shell_command @@ -559,7 +551,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "ENV_HASH=" in shell_command @@ -579,7 +571,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "groundhog-envs" in shell_command assert "ENV_DIR=" in shell_command @@ -600,7 +592,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert 'if [ -d "$ENV_DIR" ]' in shell_command assert '"$UV_BIN" venv' in shell_command @@ -622,7 +614,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert '"$ENV_DIR/bin/python"' in shell_command assert '"$UV_BIN" run' not in shell_command @@ -643,7 +635,7 @@ def func(): """ script_path.write_text(script_content) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "groundhog-meta.json" in shell_command assert '"requires_python":' in shell_command @@ -665,7 +657,7 @@ def func(): script_path.write_text(script_content) with caplog.at_level(logging.WARNING): - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "ENV_HASH=" in shell_command assert any( @@ -848,7 +840,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert '"$ENV_DIR/uv.toml"' in shell_command assert 'exclude-newer = "2025-01-01T00:00:00Z"' in shell_command @@ -872,7 +864,7 @@ def 
func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert '--config-file "$ENV_DIR/uv.toml"' in shell_command @@ -894,7 +886,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") # --exclude-newer as a standalone CLI flag should be gone import re @@ -923,7 +915,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") # uv venv line should carry --config-file venv_line = next( @@ -951,7 +943,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") toml_write_pos = shell_command.find("UV_CONFIG_EOF") venv_pos = shell_command.find('"$UV_BIN" venv') @@ -971,7 +963,7 @@ def func(): return 1 """) - shell_command = template_shell_command(str(script_path), "func", "payload") + shell_command = template_shell_command(str(script_path), "func") assert "UV_CONFIG_EOF" not in shell_command assert "--config-file" not in shell_command @@ -997,9 +989,130 @@ def compute(x): result = template_shell_command( str(script_path), "MyClass.compute", # Dotted qualname - "[[1], {}]", ) # The runner should use attrgetter for dotted paths assert "attrgetter" in result assert "MyClass.compute" in result + + +MINIMAL_SCRIPT = """\ +# /// script +# requires-python = ">=3.12" +# dependencies = [] +# /// + +import groundhog_hpc as hog + +@hog.function() +def func(): + return 42 +""" + + +class TestTemplateShellCommandParameterized: + """Tests for the parameterized shell command template.""" + + def _write_script(self, tmp_path, content=MINIMAL_SCRIPT): + p = tmp_path / "script.py" + p.write_text(content) + return str(p) + + def test_returns_a_string(self, tmp_path): + script_path = self._write_script(tmp_path) + result = template_shell_command(script_path, "func") + assert isinstance(result, str) + assert len(result) > 0 + + def test_contains_payload_placeholder_exactly_once(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + assert cmd.count("{payload}") == 1 + + def test_format_with_payload_kwarg_substitutes_correctly(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + result = cmd.format(payload="__PICKLE__:AAAA==") + assert "__PICKLE__:AAAA==" in result + assert "{payload}" not in result + + def test_format_without_payload_kwarg_raises_key_error(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + with pytest.raises(KeyError): + cmd.format() + + def test_base64_payload_is_format_safe(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + base64_payload = "__PICKLE__:ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/==" + result = cmd.format(payload=base64_payload) + assert base64_payload in result + + def test_user_code_braces_are_escaped_before_format_call(self, tmp_path): + """Dict literals in user code survive .format(payload=...) 
without KeyError.""" + script_content = """\ +# /// script +# requires-python = ">=3.12" +# dependencies = [] +# /// + +import groundhog_hpc as hog + +@hog.function() +def func(): + return {"key": "value"} +""" + script_path = self._write_script(tmp_path, script_content) + cmd = template_shell_command(script_path, "func") + # Dict braces must be doubled in cmd so .format() doesn't raise KeyError + assert '{{"key": "value"}}' in cmd + # After .format(), doubled braces collapse to single braces (dict literal preserved) + result = cmd.format(payload="test") + assert '{"key": "value"}' in result + + def test_contains_mktemp_for_file_isolation(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + assert "mktemp -d" in cmd + + def test_cleanup_uses_rm_rf_task_dir(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + assert 'rm -rf "$TASK_DIR"' in cmd + # Individual file cleanup should not appear + assert "rm -f " not in cmd + + def test_file_paths_use_fixed_names_inside_task_dir(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + assert "$TASK_DIR/user_script.py" in cmd + assert "$TASK_DIR/runner.py" in cmd + assert "$TASK_DIR/payload.in" in cmd + # No random UUID suffixes in paths + import re + + assert not re.search(r"\w+-[0-9a-f]{8}-[0-9a-f]{8}\.py", cmd) + + def test_runner_references_fixed_payload_path(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + assert "open('payload.in'" in cmd + + def test_includes_standard_uv_and_env_reuse_infrastructure(self, tmp_path): + script_path = self._write_script(tmp_path) + cmd = template_shell_command(script_path, "func") + assert "ENV_HASH=" in cmd + assert "ENV_DIR=" in cmd + assert '"$UV_BIN" venv' in cmd + assert '"$UV_BIN" pip install' in cmd + assert '"$ENV_DIR/bin/python"' in cmd + + def test_different_scripts_produce_different_commands(self, tmp_path): + script1 = tmp_path / "script1.py" + script2 = tmp_path / "script2.py" + script1.write_text(MINIMAL_SCRIPT) + script2.write_text(MINIMAL_SCRIPT.replace("return 42", "return 99")) + cmd1 = template_shell_command(str(script1), "func") + cmd2 = template_shell_command(str(script2), "func") + assert cmd1 != cmd2
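
A note for reviewers on the `{payload}` mechanism the templating tests above exercise: the snippet below is an illustrative sketch only, not the `groundhog_hpc` implementation of `template_shell_command`. It assumes, as the new assertions do, that user source is brace-escaped before being embedded in the command and that a single `{payload}` placeholder is left for call-time substitution; the shell fragment and file names are stand-ins.

```python
# Illustrative sketch, not the real template_shell_command. It only shows why
# cmd.format(payload=...) succeeds while cmd.format() raises KeyError.
user_code = 'def func():\n    return {"key": "value"}\n'

# Doubling braces makes the embedded user code inert under str.format().
escaped = user_code.replace("{", "{{").replace("}", "}}")

cmd = (
    "TASK_DIR=$(mktemp -d)\n"
    f"cat > \"$TASK_DIR/user_script.py\" <<'EOF'\n{escaped}EOF\n"
    "echo '{payload}' > \"$TASK_DIR/payload.in\"\n"  # the one live placeholder
    'rm -rf "$TASK_DIR"\n'
)

assert cmd.count("{payload}") == 1                  # placeholder survives escaping
rendered = cmd.format(payload="__PICKLE__:AAAA==")  # call-time substitution
assert '{"key": "value"}' in rendered               # doubled braces collapse back
try:
    cmd.format()                                    # no payload kwarg supplied
except KeyError:
    pass  # mirrors test_format_without_payload_kwarg_raises_key_error
```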