Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions src/groundhog_hpc/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,18 @@ def submit_to_executor(
"""
import globus_compute_sdk as gc

# Validate config against endpoint schema and filter out unexpected keys
# Validate config against endpoint schema and filter out unexpected keys,
# but only when the schema explicitly forbids extra keys.
config = user_endpoint_config.copy()
if schema := get_endpoint_schema(endpoint):
expected_keys = set(schema.get("properties", {}).keys())
unexpected_keys = set(config.keys()) - expected_keys
if unexpected_keys:
logger.debug(
f"Filtering unexpected config keys for endpoint {endpoint}: {unexpected_keys}"
)
config = {k: v for k, v in config.items() if k not in unexpected_keys}
if schema.get("additionalProperties") is False:
expected_keys = set(schema.get("properties", {}).keys())
unexpected_keys = set(config.keys()) - expected_keys
if unexpected_keys:
logger.debug(
f"Filtering unexpected config keys for endpoint {endpoint}: {unexpected_keys}"
)
config = {k: v for k, v in config.items() if k not in unexpected_keys}

logger.debug(f"Creating Globus Compute executor for endpoint {endpoint}")
with gc.Executor(
Expand Down Expand Up @@ -141,13 +143,14 @@ def submit_batch(

config = user_endpoint_config.copy()
if schema := get_endpoint_schema(endpoint):
expected_keys = set(schema.get("properties", {}).keys())
unexpected_keys = set(config.keys()) - expected_keys
if unexpected_keys:
logger.debug(
f"Filtering unexpected config keys for endpoint {endpoint}: {unexpected_keys}"
)
config = {k: v for k, v in config.items() if k not in unexpected_keys}
if schema.get("additionalProperties") is False:
expected_keys = set(schema.get("properties", {}).keys())
unexpected_keys = set(config.keys()) - expected_keys
if unexpected_keys:
logger.debug(
f"Filtering unexpected config keys for endpoint {endpoint}: {unexpected_keys}"
)
config = {k: v for k, v in config.items() if k not in unexpected_keys}

func_name = getattr(shell_function, "__name__", "unknown")
function_id = client.register_function(shell_function)
Expand Down
11 changes: 11 additions & 0 deletions src/groundhog_hpc/configuration/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ def _merge_variant_path(

config = _merge_variant_path(variant_path, base_variant, config)

# Strip nested dicts that are leftover variant configs (e.g., if the
# user targets "polaris", the "polaris.gpu" sub-dict shouldn't leak
# into the config sent to the API). At this point, any dict values
# are untraversed variant sub-configs, not legitimate endpoint settings.
variant_keys = [k for k, v in config.items() if isinstance(v, dict)]
if variant_keys:
logger.debug(
f"Stripping variant sub-configs from resolved config: {variant_keys}"
)
config = {k: v for k, v in config.items() if not isinstance(v, dict)}

# Layer 4: Merge decorator config
if decorator_config:
logger.debug(f"Merging decorator config: {decorator_config}")
Expand Down
5 changes: 4 additions & 1 deletion tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ def test_endpoint_schema_filtering_applied(self, mock_globus_client):
client = _make_batch_client(task_ids=["tid-0"])
mock_globus_client.return_value = client

schema = {"properties": {"account": {"type": "string"}}}
schema = {
"properties": {"account": {"type": "string"}},
"additionalProperties": False,
}
with patch("groundhog_hpc.compute.get_endpoint_schema", return_value=schema):
submit_batch(
_ENDPOINT,
Expand Down
5 changes: 2 additions & 3 deletions tests/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,8 @@ def test_dict_valued_keys_preserved_from_decorator_and_callsite(
"OMP_NUM_THREADS": "4",
}

# PEP 723 variant is also included (will be filtered at submit time)
assert "gpu" in result
assert isinstance(result["gpu"], dict)
# PEP 723 variant sub-dicts are stripped after resolution
assert "gpu" not in result


class TestConfigResolverPep723Variants:
Expand Down
Loading