diff --git a/src/taskgraph/optimize/base.py b/src/taskgraph/optimize/base.py index ba07ca276..0d333c624 100644 --- a/src/taskgraph/optimize/base.py +++ b/src/taskgraph/optimize/base.py @@ -22,13 +22,14 @@ from taskgraph.taskgraph import TaskGraph from taskgraph.util.parameterization import resolve_task_references, resolve_timestamps from taskgraph.util.python_path import import_sibling_modules +from taskgraph.util.schema import validate_schema from taskgraph.util.taskcluster import find_task_id_batched, status_task_batched logger = logging.getLogger("optimization") registry = {} -def register_strategy(name, args=(), kwargs=None): +def register_strategy(name, args=(), kwargs=None, schema=None): kwargs = kwargs or {} def wrap(cls): @@ -36,6 +37,7 @@ def wrap(cls): registry[name] = cls(*args, **kwargs) if not hasattr(registry[name], "description"): registry[name].description = name + registry[name].schema = schema return cls return wrap @@ -123,6 +125,13 @@ def optimizations(label): if task.optimization: opt_by, arg = list(task.optimization.items())[0] strategy = strategies[opt_by] + schema = getattr(strategy, "schema", None) + if schema: + validate_schema( + schema, + arg, + f"In task `{label}` optimization `{opt_by}`:", + ) if hasattr(strategy, "description"): opt_by += f" ({strategy.description})" return (opt_by, strategy, arg) diff --git a/src/taskgraph/optimize/strategies.py b/src/taskgraph/optimize/strategies.py index 8fed9e54a..7b461fe2a 100644 --- a/src/taskgraph/optimize/strategies.py +++ b/src/taskgraph/optimize/strategies.py @@ -5,12 +5,13 @@ from taskgraph.optimize.base import OptimizationStrategy, register_strategy from taskgraph.util.path import match as match_path +from taskgraph.util.schema import Schema from taskgraph.util.taskcluster import find_task_id, status_task logger = logging.getLogger("optimization") -@register_strategy("index-search") +@register_strategy("index-search", schema=Schema([str])) class IndexSearch(OptimizationStrategy): # A task with no dependencies remaining after optimization will be replaced # if artifacts exist for the corresponding index_paths. @@ -73,7 +74,7 @@ def should_replace_task(self, task, params, deadline, arg): return False -@register_strategy("skip-unless-changed") +@register_strategy("skip-unless-changed", schema=Schema([str])) class SkipUnlessChanged(OptimizationStrategy): def check(self, files_changed, patterns): for pattern in patterns: diff --git a/src/taskgraph/transforms/task.py b/src/taskgraph/transforms/task.py index 7c834dcc7..e6dc127cd 100644 --- a/src/taskgraph/transforms/task.py +++ b/src/taskgraph/transforms/task.py @@ -23,7 +23,6 @@ from taskgraph.util.hash import hash_path from taskgraph.util.keyed_by import evaluate_keyed_by from taskgraph.util.schema import ( - OptimizationSchema, Schema, optionally_keyed_by, resolve_keyed_by, @@ -340,10 +339,10 @@ def run_task_suffix(): description=dedent( """ Optimization to perform on this task during the optimization - phase. Defined in taskcluster/taskgraph/optimize.py. + phase. The schema for this value is specific to the given optimization. """.lstrip() ), - ): OptimizationSchema, + ): Any(None, dict), Required( "worker-type", description=dedent( diff --git a/src/taskgraph/util/schema.py b/src/taskgraph/util/schema.py index 3c5f4c955..f6d2ee9d3 100644 --- a/src/taskgraph/util/schema.py +++ b/src/taskgraph/util/schema.py @@ -230,16 +230,6 @@ def __getitem__(self, item): return self.schema[item] # type: ignore -OptimizationSchema = voluptuous.Any( - # always run this task (default) - None, - # search the index for the given index namespaces, and replace this task if found - # the search occurs in order, with the first match winning - {"index-search": [str]}, - # skip this task if none of the given file patterns match - {"skip-unless-changed": [str]}, -) - # shortcut for a string where task references are allowed taskref_or_string = voluptuous.Any( str, diff --git a/test/test_optimize.py b/test/test_optimize.py index bfc2e9709..07c2a6a9b 100644 --- a/test/test_optimize.py +++ b/test/test_optimize.py @@ -7,6 +7,7 @@ import pytest from pytest_taskgraph import make_graph, make_task +from voluptuous import Schema from taskgraph.graph import Graph from taskgraph.optimize import base as optimize_mod @@ -487,3 +488,15 @@ def test_register_strategy(mocker): func = register_strategy("foo", args=("one", "two"), kwargs={"n": 1}) func(m) m.assert_called_with("one", "two", n=1) + + +def test_register_strategy_with_schema(mocker, monkeypatch): + monkeypatch.setattr(optimize_mod, "registry", {}) + schema = Schema([str]) + + @register_strategy("bar", schema=schema) + class TestStrategy(OptimizationStrategy): + pass + + assert "bar" in optimize_mod.registry + assert optimize_mod.registry["bar"].schema is schema