diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 5ff336c18..f6a55d84f 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -20,6 +20,7 @@ import json import logging import sys +from collections.abc import Callable from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -895,36 +896,51 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: "max_dataset_size": None, } + def _collect_list_values( + *, + flag: str, + start_index: int, + parse_value: Optional[Callable[[str], Any]] = None, + stop_on_short_strategy_flag: bool = False, + ) -> tuple[list[Any], int]: + values: list[Any] = [] + index = start_index + 1 + + while index < len(parts) and not parts[index].startswith("--"): + if stop_on_short_strategy_flag and parts[index] == "-s": + break + values.append(parse_value(parts[index]) if parse_value else parts[index]) + index += 1 + + if not values: + raise ValueError(f"{flag} requires at least one value") + + return values, index + i = 1 while i < len(parts): if parts[i] == "--initializers": - # Collect initializers until next flag, parsing name:key=val syntax - result["initializers"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initializers"].append(_parse_initializer_arg(parts[i])) - i += 1 + result["initializers"], i = _collect_list_values( + flag="--initializers", + start_index=i, + parse_value=_parse_initializer_arg, + ) elif parts[i] == "--initialization-scripts": - # Collect script paths until next flag - result["initialization_scripts"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initialization_scripts"].append(parts[i]) - i += 1 + result["initialization_scripts"], i = _collect_list_values( + flag="--initialization-scripts", + start_index=i, + ) elif parts[i] == "--env-files": - # Collect env file paths until next flag - result["env_files"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["env_files"].append(parts[i]) - i += 1 + result["env_files"], i = _collect_list_values( + flag="--env-files", + start_index=i, + ) elif parts[i] in ("--strategies", "-s"): - # Collect strategies until next flag - result["scenario_strategies"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s": - result["scenario_strategies"].append(parts[i]) - i += 1 + result["scenario_strategies"], i = _collect_list_values( + flag=parts[i], + start_index=i, + stop_on_short_strategy_flag=True, + ) elif parts[i] == "--max-concurrency": i += 1 if i >= len(parts): @@ -956,12 +972,10 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: result["log_level"] = validate_log_level(log_level=parts[i]) i += 1 elif parts[i] == "--dataset-names": - # Collect dataset names until next flag - result["dataset_names"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["dataset_names"].append(parts[i]) - i += 1 + result["dataset_names"], i = _collect_list_values( + flag="--dataset-names", + start_index=i, + ) elif parts[i] == "--max-dataset-size": i += 1 if i >= len(parts): diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 7c040deb5..6042a393b 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -684,6 +684,22 @@ def test_parse_run_arguments_missing_value(self): with pytest.raises(ValueError, match="requires a value"): frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency") + @pytest.mark.parametrize( + ("args_string", "flag"), + [ + ("test_scenario --initializers --max-retries 2", "--initializers"), + ("test_scenario --initialization-scripts --max-retries 2", "--initialization-scripts"), + ("test_scenario --env-files --database InMemory", "--env-files"), + ("test_scenario --strategies --max-retries 2", "--strategies"), + ("test_scenario -s --max-retries 2", "-s"), + ("test_scenario --dataset-names --max-dataset-size 5", "--dataset-names"), + ], + ) + def test_parse_run_arguments_list_flags_require_values(self, args_string: str, flag: str): + """Test list-style arguments require at least one value.""" + with pytest.raises(ValueError, match=rf"{flag} requires at least one value"): + frontend_core.parse_run_arguments(args_string=args_string) + @pytest.mark.asyncio @pytest.mark.usefixtures("patch_central_database")