From dc35e27a0ad140aeef63dec5ab20346e8e01e415 Mon Sep 17 00:00:00 2001 From: Kylian Ronfleux--Corail <35237015+Kyliroco@users.noreply.github.com> Date: Thu, 30 Oct 2025 15:49:29 +0100 Subject: [PATCH 1/2] Allow naming matrix expansions --- dvc/parsing/__init__.py | 83 +++++++++++++++++++++++++++++-- dvc/schema.py | 1 + dvc/stage/params.py | 1 + tests/func/parsing/test_matrix.py | 52 ++++++++++++++++++- 4 files changed, 133 insertions(+), 4 deletions(-) diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index a0a65d59a2..b32e7f2d57 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -10,6 +10,7 @@ from dvc.exceptions import DvcException from dvc.log import logger from dvc.parsing.interpolate import ParseError +from dvc.stage.params import StageParams from dvc.utils.objects import cached_property from .context import ( @@ -518,6 +519,7 @@ def __init__( self._template = definition.copy() self.matrix_data = self._template.pop(MATRIX_KWD) + self._custom_name = self._template.pop(StageParams.PARAM_NAME, None) self.pair = IterationPair() self.where = where @@ -577,19 +579,94 @@ def normalized_iterable(self) -> dict[str, "DictStrAny"]: return ret def has_member(self, key: str) -> bool: - return key in self.normalized_iterable + resolved = self._resolve_iteration_key(key) + return resolved in self.normalized_iterable def get_generated_names(self) -> list[str]: return list(map(self._generate_name, self.normalized_iterable)) def _generate_name(self, key: str) -> str: - return f"{self.name}{JOIN}{key}" + if not self._custom_name: + return f"{self.name}{JOIN}{key}" + + suffix = self._custom_suffixes[key] + return f"{self.name}{JOIN}{suffix}" def resolve_all(self) -> "DictStrAny": return join(map(self.resolve_one, self.normalized_iterable)) def resolve_one(self, key: str) -> "DictStrAny": - return self._each_iter(key) + resolved = self._resolve_iteration_key(key) + return self._each_iter(resolved) + + def _resolve_iteration_key(self, key: str) -> str: + if not self._custom_name or key in self.normalized_iterable: + return key + + for original, suffix in self._custom_suffixes.items(): + if suffix == key: + return original + + return key + + def _render_custom_suffix(self, key: str) -> str: + value = self.normalized_iterable[key] + temp_dict = {self.pair.key: key, self.pair.value: value} + with self.context.set_temporarily(temp_dict, reserve=True): + try: + resolved = self.context.resolve_str(self._custom_name) + except (ContextError, ParseError) as exc: + format_and_raise( + exc, + f"'{self.where}.{self.name}.{StageParams.PARAM_NAME}'", + self.relpath, + ) + + if not isinstance(resolved, (str, int, float, bool)): + format_and_raise( + ResolveError( + "matrix stage name must resolve to a string or primitive value" + ), + f"'{self.where}.{self.name}.{StageParams.PARAM_NAME}'", + self.relpath, + ) + + suffix = to_str(resolved) + if not suffix: + format_and_raise( + ResolveError("matrix stage name cannot be empty"), + f"'{self.where}.{self.name}.{StageParams.PARAM_NAME}'", + self.relpath, + ) + + if JOIN in suffix: + format_and_raise( + ResolveError(f"matrix stage name cannot contain '{JOIN}'"), + f"'{self.where}.{self.name}.{StageParams.PARAM_NAME}'", + self.relpath, + ) + + return suffix + + @cached_property + def _custom_suffixes(self) -> dict[str, str]: + if not self._custom_name: + return {} + + seen: set[str] = set() + suffixes: dict[str, str] = {} + for key in self.normalized_iterable: + suffix = self._render_custom_suffix(key) + if suffix in seen: + format_and_raise( + ResolveError(f"matrix stage name '{suffix}' is already defined"), + f"'{self.where}.{self.name}.{StageParams.PARAM_NAME}'", + self.relpath, + ) + suffixes[key] = suffix + seen.add(suffix) + + return suffixes def _each_iter(self, key: str) -> "DictStrAny": err_message = f"Could not find '{key}' in matrix group '{self.name}'" diff --git a/dvc/schema.py b/dvc/schema.py index 85001016fc..76d8ce4cea 100644 --- a/dvc/schema.py +++ b/dvc/schema.py @@ -93,6 +93,7 @@ vol.Optional(StageParams.PARAM_OUTS): [vol.Any(str, OUT_PSTAGE_DETAILED_SCHEMA)], vol.Optional(StageParams.PARAM_METRICS): [vol.Any(str, OUT_PSTAGE_DETAILED_SCHEMA)], vol.Optional(StageParams.PARAM_PLOTS): [vol.Any(str, PLOT_PSTAGE_SCHEMA)], + vol.Optional(StageParams.PARAM_NAME): str, } diff --git a/dvc/stage/params.py b/dvc/stage/params.py index 14ca18b250..cdb60f453e 100644 --- a/dvc/stage/params.py +++ b/dvc/stage/params.py @@ -12,3 +12,4 @@ class StageParams: PARAM_METRICS = "metrics" PARAM_PLOTS = "plots" PARAM_DESC = "desc" + PARAM_NAME = "name" diff --git a/tests/func/parsing/test_matrix.py b/tests/func/parsing/test_matrix.py index 26ffb527b1..4b9bcf1811 100644 --- a/tests/func/parsing/test_matrix.py +++ b/tests/func/parsing/test_matrix.py @@ -1,6 +1,6 @@ import pytest -from dvc.parsing import DataResolver, MatrixDefinition +from dvc.parsing import DataResolver, MatrixDefinition, ResolveError MATRIX_DATA = { "os": ["win", "linux"], @@ -91,3 +91,53 @@ def test_matrix_key_present(tmp_dir, dvc, matrix): "build@linux-3.8-dict1-list0": {"cmd": "echo linux-3.8-dict1-list0"}, "build@linux-3.8-dict1-list1": {"cmd": "echo linux-3.8-dict1-list1"}, } + + +def test_matrix_custom_name(tmp_dir, dvc): + matrix = { + "dataset": [{"key": "dataset_a"}], + "model": [{"key": "model_alpha"}], + } + resolver = DataResolver(dvc, tmp_dir.fs_path, {}) + data = { + "matrix": matrix, + "name": "${item.model.key}_${item.dataset.key}", + "cmd": "echo ${item.model.key} ${item.dataset.key}", + } + definition = MatrixDefinition(resolver, resolver.context, "inference", data) + + assert definition.get_generated_names() == ["inference@model_alpha_dataset_a"] + assert definition.has_member("model_alpha_dataset_a") + assert definition.resolve_one("model_alpha_dataset_a") == { + "inference@model_alpha_dataset_a": { + "cmd": "echo model_alpha dataset_a" + } + } + + +def test_matrix_custom_name_duplicate_error(tmp_dir, dvc): + matrix = {"model": [{"key": "same"}, {"key": "same"}]} + resolver = DataResolver(dvc, tmp_dir.fs_path, {}) + data = { + "matrix": matrix, + "name": "${item.model.key}", + "cmd": "echo ${item.model.key}", + } + definition = MatrixDefinition(resolver, resolver.context, "train", data) + + with pytest.raises(ResolveError, match="already defined"): + definition.get_generated_names() + + +def test_matrix_custom_name_invalid_suffix(tmp_dir, dvc): + matrix = {"model": [{"key": "same"}]} + resolver = DataResolver(dvc, tmp_dir.fs_path, {}) + data = { + "matrix": matrix, + "name": "bad@name", + "cmd": "echo ${item.model.key}", + } + definition = MatrixDefinition(resolver, resolver.context, "train", data) + + with pytest.raises(ResolveError, match="cannot contain"): + definition.get_generated_names() From 8fcad6e190aeb991c470e1216d6e700f3bb5412e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Oct 2025 15:48:53 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/func/parsing/test_matrix.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/func/parsing/test_matrix.py b/tests/func/parsing/test_matrix.py index 4b9bcf1811..de889d451d 100644 --- a/tests/func/parsing/test_matrix.py +++ b/tests/func/parsing/test_matrix.py @@ -109,9 +109,7 @@ def test_matrix_custom_name(tmp_dir, dvc): assert definition.get_generated_names() == ["inference@model_alpha_dataset_a"] assert definition.has_member("model_alpha_dataset_a") assert definition.resolve_one("model_alpha_dataset_a") == { - "inference@model_alpha_dataset_a": { - "cmd": "echo model_alpha dataset_a" - } + "inference@model_alpha_dataset_a": {"cmd": "echo model_alpha dataset_a"} }