diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index d96f7a7ca..279fa8988 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.path import DEFAULT_CONFIG_PATH +from pyrit.common.utils import verify_and_resolve_path from pyrit.common.yaml_loadable import YamlLoadable from pyrit.identifiers.class_name_utils import class_name_to_snake_case from pyrit.setup.initialization import ( @@ -100,6 +101,8 @@ class ConfigurationLoader(YamlLoadable): silent: bool = False operator: Optional[str] = None operation: Optional[str] = None + _initialization_scripts_base_path: Optional[pathlib.Path] = field(default=None, init=False, repr=False) + _env_files_base_path: Optional[pathlib.Path] = field(default=None, init=False, repr=False) def __post_init__(self) -> None: """Validate and normalize the configuration after loading.""" @@ -179,6 +182,48 @@ def from_dict(cls, data: dict[str, Any]) -> "ConfigurationLoader": filtered_data = {k: v for k, v in data.items() if v is not None} return cls(**filtered_data) + @classmethod + def from_yaml_file(cls, file: pathlib.Path | str) -> "ConfigurationLoader": + """ + Create a ConfigurationLoader from a YAML file and preserve its base directory. + + Relative initialization script and env file paths should resolve from the + configuration file directory rather than the caller's working directory. + + Returns: + A new ConfigurationLoader instance with per-field path resolution bases. + """ + resolved_file = verify_and_resolve_path(file) + config = YamlLoadable.from_yaml_file.__func__(cls, resolved_file) + config._set_path_resolution_base_paths( + initialization_scripts_base_path=resolved_file.parent, + env_files_base_path=resolved_file.parent, + ) + return config + + def _set_path_resolution_base_paths( + self, + *, + initialization_scripts_base_path: Optional[pathlib.Path], + env_files_base_path: Optional[pathlib.Path], + ) -> None: + """Set per-field base paths for resolving relative configuration paths.""" + self._initialization_scripts_base_path = initialization_scripts_base_path + self._env_files_base_path = env_files_base_path + + @staticmethod + def _resolve_config_path(path_str: str, base_path: Optional[pathlib.Path]) -> pathlib.Path: + """ + Resolve config-provided relative paths against an optional base directory. + + Returns: + An absolute path when a relative base is available, or the original absolute path. + """ + config_path = pathlib.Path(path_str) + if config_path.is_absolute(): + return config_path + return (base_path or pathlib.Path.cwd()) / config_path + @staticmethod def load_with_overrides( config_file: Optional[pathlib.Path] = None, @@ -217,6 +262,8 @@ def load_with_overrides( import logging logger = logging.getLogger(__name__) + initialization_scripts_base_path: Optional[pathlib.Path] = None + env_files_base_path: Optional[pathlib.Path] = None # Start with defaults - None means "use defaults", [] means "load nothing" config_data: dict[str, Any] = { @@ -239,6 +286,8 @@ def load_with_overrides( # Preserve None vs [] distinction from config file config_data["initialization_scripts"] = default_config.initialization_scripts config_data["env_files"] = default_config.env_files + initialization_scripts_base_path = default_config._initialization_scripts_base_path + env_files_base_path = default_config._env_files_base_path if default_config.operator: config_data["operator"] = default_config.operator if default_config.operation: @@ -259,6 +308,8 @@ def load_with_overrides( # Preserve None vs [] distinction from config file config_data["initialization_scripts"] = explicit_config.initialization_scripts config_data["env_files"] = explicit_config.env_files + initialization_scripts_base_path = explicit_config._initialization_scripts_base_path + env_files_base_path = explicit_config._env_files_base_path if explicit_config.operator: config_data["operator"] = explicit_config.operator if explicit_config.operation: @@ -280,11 +331,18 @@ def load_with_overrides( if initialization_scripts is not None: config_data["initialization_scripts"] = list(initialization_scripts) + initialization_scripts_base_path = None if env_files is not None: config_data["env_files"] = list(env_files) + env_files_base_path = None - return ConfigurationLoader.from_dict(config_data) + config = ConfigurationLoader.from_dict(config_data) + config._set_path_resolution_base_paths( + initialization_scripts_base_path=initialization_scripts_base_path, + env_files_base_path=env_files_base_path, + ) + return config @classmethod def get_default_config_path(cls) -> pathlib.Path: @@ -325,8 +383,12 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: f"Initializer '{config.name}' not found in registry.\nAvailable initializers: {available}" ) - # Instantiate with args if provided - instance = initializer_class(**config.args) if config.args else initializer_class() + # Instantiate and set params if provided + instance = initializer_class() + if config.args: + instance.set_params_from_args(args=config.args) + # Validate params early against supported_parameters to fail fast + instance._validate_params(params=instance.params) resolved.append(instance) @@ -348,14 +410,10 @@ def _resolve_initialization_scripts(self) -> Optional[Sequence[pathlib.Path]]: if len(self.initialization_scripts) == 0: return [] - resolved: list[pathlib.Path] = [] - for script_str in self.initialization_scripts: - script_path = pathlib.Path(script_str) - if not script_path.is_absolute(): - script_path = pathlib.Path.cwd() / script_path - resolved.append(script_path) - - return resolved + return [ + self._resolve_config_path(script_str, self._initialization_scripts_base_path) + for script_str in self.initialization_scripts + ] def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: """ @@ -373,14 +431,10 @@ def _resolve_env_files(self) -> Optional[Sequence[pathlib.Path]]: if len(self.env_files) == 0: return [] - resolved: list[pathlib.Path] = [] - for env_str in self.env_files: - env_path = pathlib.Path(env_str) - if not env_path.is_absolute(): - env_path = pathlib.Path.cwd() / env_path - resolved.append(env_path) - - return resolved + return [ + self._resolve_config_path(env_str, self._env_files_base_path) + for env_str in self.env_files + ] async def initialize_pyrit_async(self) -> None: """ diff --git a/tests/unit/setup/test_configuration_loader.py b/tests/unit/setup/test_configuration_loader.py index b4e737bf5..4c070e75a 100644 --- a/tests/unit/setup/test_configuration_loader.py +++ b/tests/unit/setup/test_configuration_loader.py @@ -231,6 +231,26 @@ def test_resolve_initialization_scripts_relative_path(self): # Check path ends with expected components (works on both Unix and Windows) assert resolved[0].parts[-2:] == ("relative", "script.py") + @pytest.mark.parametrize( + ("field_name", "resolver_name", "relative_path"), + [ + ("initialization_scripts", "_resolve_initialization_scripts", "scripts/init.py"), + ("env_files", "_resolve_env_files", "env/local.env"), + ], + ) + def test_from_yaml_file_resolves_relative_paths_from_config_directory( + self, tmp_path, field_name, resolver_name, relative_path + ): + """Test relative paths from YAML are resolved from the config file directory.""" + config_path = tmp_path / "configs" / "pyrit.yaml" + config_path.parent.mkdir() + config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8") + + config = ConfigurationLoader.from_yaml_file(config_path) + resolved = getattr(config, resolver_name)() + + assert resolved == [config_path.parent / relative_path] + def test_resolve_env_files_none_returns_none(self): """Test that None (default) returns None to signal 'use defaults'.""" config = ConfigurationLoader() @@ -421,6 +441,50 @@ def test_load_with_overrides_env_files_override(self, mock_default_path): assert config.env_files == ["/path/to/.env"] + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + @pytest.mark.parametrize( + ("field_name", "resolver_name", "relative_path"), + [ + ("initialization_scripts", "_resolve_initialization_scripts", "scripts/init.py"), + ("env_files", "_resolve_env_files", "env/local.env"), + ], + ) + def test_load_with_overrides_resolves_relative_paths_from_config_directory( + self, mock_default_path, tmp_path, field_name, resolver_name, relative_path + ): + """Test config file relative paths are resolved from the config file directory.""" + mock_default_path.exists.return_value = False + config_path = tmp_path / "configs" / "pyrit.yaml" + config_path.parent.mkdir() + config_path.write_text(f"{field_name}:\n - ./{relative_path}\n", encoding="utf-8") + + config = ConfigurationLoader.load_with_overrides(config_file=config_path) + resolved = getattr(config, resolver_name)() + + assert resolved == [config_path.parent / relative_path] + + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + @pytest.mark.parametrize( + ("field_name", "resolver_name", "relative_path"), + [ + ("initialization_scripts", "_resolve_initialization_scripts", "scripts/override.py"), + ("env_files", "_resolve_env_files", "env/override.env"), + ], + ) + def test_load_with_overrides_cli_relative_paths_use_cwd( + self, mock_default_path, tmp_path, field_name, resolver_name, relative_path + ): + """Test CLI path overrides keep resolving relative paths from the current directory.""" + mock_default_path.exists.return_value = False + config_path = tmp_path / "configs" / "pyrit.yaml" + config_path.parent.mkdir() + config_path.write_text(f"{field_name}:\n - ./from-config-placeholder\n", encoding="utf-8") + + config = ConfigurationLoader.load_with_overrides(config_file=config_path, **{field_name: [relative_path]}) + resolved = getattr(config, resolver_name)() + + assert resolved == [pathlib.Path.cwd() / relative_path] + @mock.patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") def test_load_with_overrides_converts_sequence_to_list(self, mock_default_path): """Test that Sequence inputs are converted to list for dataclass compatibility."""