Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 80 additions & 3 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dvc.exceptions import DvcException
from dvc.log import logger
from dvc.parsing.interpolate import ParseError
from dvc.stage.params import StageParams
from dvc.utils.objects import cached_property

from .context import (
Expand Down Expand Up @@ -518,6 +519,7 @@ def __init__(

self._template = definition.copy()
self.matrix_data = self._template.pop(MATRIX_KWD)
self._custom_name = self._template.pop(StageParams.PARAM_NAME, None)

self.pair = IterationPair()
self.where = where
Expand Down Expand Up @@ -577,19 +579,94 @@ def normalized_iterable(self) -> dict[str, "DictStrAny"]:
return ret

def has_member(self, key: str) -> bool:
    """Return whether ``key`` identifies one of this definition's iterations.

    ``key`` is first mapped through ``_resolve_iteration_key`` so that a
    rendered custom suffix is accepted as well as an original iteration key.

    Args:
        key: Iteration key or custom suffix to look up.

    Returns:
        ``True`` if the resolved key is present in ``normalized_iterable``.
    """
    # Membership must be tested on the *resolved* key; returning early on the
    # raw key would leave the custom-suffix lookup unreachable.
    return self._resolve_iteration_key(key) in self.normalized_iterable

def get_generated_names(self) -> list[str]:
    """Return the generated stage name of every iteration, in iteration order."""
    return [
        self._generate_name(iteration_key)
        for iteration_key in self.normalized_iterable
    ]

def _generate_name(self, key: str) -> str:
    """Build the full stage name for the iteration identified by ``key``.

    Without a custom ``name`` template the suffix is the iteration key
    itself; with one, the suffix is the rendered value stored in
    ``_custom_suffixes``.

    Args:
        key: An original iteration key (a key of ``normalized_iterable``).

    Returns:
        ``self.name`` and the suffix separated by ``JOIN``.
    """
    # A single exit point: an unconditional early return on the raw key would
    # make the custom-name branch unreachable.
    suffix = self._custom_suffixes[key] if self._custom_name else key
    return f"{self.name}{JOIN}{suffix}"

def resolve_all(self) -> "DictStrAny":
    """Resolve every iteration and merge the per-iteration results via ``join``."""
    per_iteration = (
        self.resolve_one(iteration_key)
        for iteration_key in self.normalized_iterable
    )
    return join(per_iteration)

def resolve_one(self, key: str) -> "DictStrAny":
    """Resolve the single iteration identified by ``key``.

    Args:
        key: An original iteration key or a rendered custom suffix.

    Returns:
        Whatever ``_each_iter`` produces for the resolved iteration key.
    """
    # Translate a custom suffix back to its iteration key first; calling
    # ``_each_iter`` on the raw key would bypass that translation entirely.
    return self._each_iter(self._resolve_iteration_key(key))

def _resolve_iteration_key(self, key: str) -> str:
if not self._custom_name or key in self.normalized_iterable:
return key

for original, suffix in self._custom_suffixes.items():
if suffix == key:
return original

return key

def _render_custom_suffix(self, key: str) -> str:
    """Render and validate the custom ``name`` template for one iteration.

    The iteration's key and value are bound in the context, under the names
    held by ``self.pair``, while the template is resolved. The result must
    be a ``str``, ``int``, ``float`` or ``bool``; it is converted with
    ``to_str`` and must be non-empty and must not contain ``JOIN``.

    Args:
        key: A key of ``normalized_iterable``.

    Returns:
        The validated suffix for this iteration.

    Every failure (resolution error, unsupported type, empty result, or a
    result containing ``JOIN``) is reported through ``format_and_raise``.
    """
    item = self.normalized_iterable[key]
    # Same location string for every failure reported from this method.
    location = f"'{self.where}.{self.name}.{StageParams.PARAM_NAME}'"
    bindings = {self.pair.key: key, self.pair.value: item}

    with self.context.set_temporarily(bindings, reserve=True):
        try:
            rendered = self.context.resolve_str(self._custom_name)
        except (ContextError, ParseError) as exc:
            format_and_raise(exc, location, self.relpath)

    if not isinstance(rendered, (str, int, float, bool)):
        format_and_raise(
            ResolveError(
                "matrix stage name must resolve to a string or primitive value"
            ),
            location,
            self.relpath,
        )

    text = to_str(rendered)
    if not text:
        format_and_raise(
            ResolveError("matrix stage name cannot be empty"),
            location,
            self.relpath,
        )

    # The separator is reserved for splitting "<group><JOIN><suffix>".
    if JOIN in text:
        format_and_raise(
            ResolveError(f"matrix stage name cannot contain '{JOIN}'"),
            location,
            self.relpath,
        )

    return text

@cached_property
def _custom_suffixes(self) -> dict[str, str]:
if not self._custom_name:
return {}

seen: set[str] = set()
suffixes: dict[str, str] = {}
for key in self.normalized_iterable:
suffix = self._render_custom_suffix(key)
if suffix in seen:
format_and_raise(
ResolveError(f"matrix stage name '{suffix}' is already defined"),
f"'{self.where}.{self.name}.{StageParams.PARAM_NAME}'",
self.relpath,
)
suffixes[key] = suffix
seen.add(suffix)

return suffixes

def _each_iter(self, key: str) -> "DictStrAny":
err_message = f"Could not find '{key}' in matrix group '{self.name}'"
Expand Down
1 change: 1 addition & 0 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
vol.Optional(StageParams.PARAM_OUTS): [vol.Any(str, OUT_PSTAGE_DETAILED_SCHEMA)],
vol.Optional(StageParams.PARAM_METRICS): [vol.Any(str, OUT_PSTAGE_DETAILED_SCHEMA)],
vol.Optional(StageParams.PARAM_PLOTS): [vol.Any(str, PLOT_PSTAGE_SCHEMA)],
vol.Optional(StageParams.PARAM_NAME): str,
}


Expand Down
1 change: 1 addition & 0 deletions dvc/stage/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ class StageParams:
PARAM_METRICS = "metrics"
PARAM_PLOTS = "plots"
PARAM_DESC = "desc"
PARAM_NAME = "name"
50 changes: 49 additions & 1 deletion tests/func/parsing/test_matrix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from dvc.parsing import DataResolver, MatrixDefinition
from dvc.parsing import DataResolver, MatrixDefinition, ResolveError

MATRIX_DATA = {
"os": ["win", "linux"],
Expand Down Expand Up @@ -91,3 +91,51 @@ def test_matrix_key_present(tmp_dir, dvc, matrix):
"build@linux-3.8-dict1-list0": {"cmd": "echo linux-3.8-dict1-list0"},
"build@linux-3.8-dict1-list1": {"cmd": "echo linux-3.8-dict1-list1"},
}


def test_matrix_custom_name(tmp_dir, dvc):
    """A ``name`` template supplies the suffix of the generated stage name."""
    suffix = "model_alpha_dataset_a"
    stage_name = "inference@model_alpha_dataset_a"

    resolver = DataResolver(dvc, tmp_dir.fs_path, {})
    stage_template = {
        "matrix": {
            "dataset": [{"key": "dataset_a"}],
            "model": [{"key": "model_alpha"}],
        },
        "name": "${item.model.key}_${item.dataset.key}",
        "cmd": "echo ${item.model.key} ${item.dataset.key}",
    }
    definition = MatrixDefinition(
        resolver, resolver.context, "inference", stage_template
    )

    assert definition.get_generated_names() == [stage_name]
    # The rendered suffix must be usable as a lookup key as well.
    assert definition.has_member(suffix)
    resolved = definition.resolve_one(suffix)
    assert resolved == {stage_name: {"cmd": "echo model_alpha dataset_a"}}


def test_matrix_custom_name_duplicate_error(tmp_dir, dvc):
    """Two iterations rendering the same custom name raise ``ResolveError``."""
    resolver = DataResolver(dvc, tmp_dir.fs_path, {})
    stage_template = {
        # Both entries render to "same" through the name template below.
        "matrix": {"model": [{"key": "same"}, {"key": "same"}]},
        "name": "${item.model.key}",
        "cmd": "echo ${item.model.key}",
    }
    definition = MatrixDefinition(resolver, resolver.context, "train", stage_template)

    with pytest.raises(ResolveError, match="already defined"):
        definition.get_generated_names()


def test_matrix_custom_name_invalid_suffix(tmp_dir, dvc):
    """A custom name containing ``@`` is rejected with ``ResolveError``."""
    resolver = DataResolver(dvc, tmp_dir.fs_path, {})
    stage_template = {
        "matrix": {"model": [{"key": "same"}]},
        "name": "bad@name",
        "cmd": "echo ${item.model.key}",
    }
    definition = MatrixDefinition(resolver, resolver.context, "train", stage_template)

    with pytest.raises(ResolveError, match="cannot contain"):
        definition.get_generated_names()