diff --git a/.pyrit_conf_example b/.pyrit_conf_example index 92ec37c4a..13abcc8df 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -28,7 +28,6 @@ memory_db_type: sqlite # - scorers: Registers pre-configured scorers into the ScorerRegistry # - load_default_datasets: Loads default datasets for all registered scenarios # - objective_list: Sets default objectives for scenarios -# - openai_objective_target: Sets up OpenAI target for scenarios # # Each initializer can be specified as: # - A simple string (name only) diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index 859cb107a..85101a4c0 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -296,7 +296,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo quick_start = [ "Quick Start:", " pyrit> list-scenarios", - " pyrit> run foundry --initializers openai_objective_target load_default_datasets", + " pyrit> run foundry --target my_target --initializers targets load_default_datasets", ] for qs in quick_start: full_line = _box_line(" " + qs) diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 8472a5afe..623ce7732 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -286,6 +286,9 @@ def parse_memory_labels(json_string: str) -> dict[str, str]: "Creates a new dataset config; fetches all items unless --max-dataset-size is also specified", "max_dataset_size": "Maximum number of items to use from the dataset (must be >= 1). " "Limits new datasets if --dataset-names provided, otherwise overrides scenario's default limit", + "target": "Name of a registered target from the TargetRegistry to use as the objective target. " + "Targets are registered by initializers (e.g., 'targets' initializer). " + "Use --list-targets to see available target names after initializers have run", } @@ -372,15 +375,14 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: "scenario_name": parts[0], "initializers": None, "initialization_scripts": None, - "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, "memory_labels": None, - "database": None, "log_level": None, "dataset_names": None, "max_dataset_size": None, + "target": None, } i = 1 @@ -399,13 +401,6 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: while i < len(parts) and not parts[i].startswith("--"): result["initialization_scripts"].append(parts[i]) i += 1 - elif parts[i] == "--env-files": - # Collect env file paths until next flag - result["env_files"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["env_files"].append(parts[i]) - i += 1 elif parts[i] in ("--strategies", "-s"): # Collect strategies until next flag result["scenario_strategies"] = [] @@ -431,12 +426,6 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: raise ValueError("--memory-labels requires a value") result["memory_labels"] = parse_memory_labels(parts[i]) i += 1 - elif parts[i] == "--database": - i += 1 - if i >= len(parts): - raise ValueError("--database requires a value") - result["database"] = validate_database(database=parts[i]) - i += 1 elif parts[i] == "--log-level": i += 1 if i >= len(parts): @@ -456,6 +445,12 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: raise ValueError("--max-dataset-size requires a value") result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) i += 1 + elif parts[i] == "--target": + i += 1 + if i >= len(parts): + raise ValueError("--target requires a value") + result["target"] = parts[i] + i += 1 else: raise ValueError(f"Unknown argument: {parts[i]}") @@ -470,24 +465,12 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: def add_common_arguments(parser: argparse.ArgumentParser) -> None: """Add arguments shared between pyrit_shell and pyrit_scan.""" parser.add_argument("--config-file", type=Path, help=ARG_HELP["config_file"]) - parser.add_argument( - "--database", - type=validate_database_argparse, - default=None, - help=f"Database type to use ({IN_MEMORY}, {SQLITE}, {AZURE_SQL}). Defaults to config or {SQLITE}.", - ) parser.add_argument( "--log-level", type=validate_log_level_argparse, default=logging.WARNING, help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", ) - parser.add_argument( - "--env-files", - type=str, - nargs="+", - help=ARG_HELP["env_files"], - ) # Module-level logger (stdlib only — no heavy deps) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index ff4a6321c..e70fb4272 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -38,7 +38,7 @@ from pyrit.cli._cli_args import validate_integer as validate_integer from pyrit.cli._cli_args import validate_log_level as validate_log_level from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse -from pyrit.registry import InitializerRegistry, ScenarioRegistry +from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import DatasetConfiguration from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter from pyrit.setup import ConfigurationLoader, initialize_pyrit_async @@ -173,6 +173,8 @@ async def initialize_async(self) -> None: initializers=None, env_files=self._env_files, ) + # Mark that initial env loading has been printed + self._silent_reinit = True # Load registries (use singleton pattern for shared access) self._scenario_registry = ScenarioRegistry.get_registry_singleton() @@ -253,10 +255,57 @@ async def list_initializers_async( return context.initializer_registry.list_metadata() +async def list_targets_async( + *, + context: FrontendCore, + initializer_names: Optional[list[Any]] = None, +) -> list[str]: + """ + List available target names from the TargetRegistry. + + Since targets are registered by initializers, this function requires initializers + to have been run first. If initializer_names are provided, they will be resolved + and run before querying the registry. + + Args: + context: PyRIT context with loaded registries. + initializer_names: Optional list of initializer entries to run before listing. + + Returns: + Sorted list of registered target names. + """ + if not context._initialized: + await context.initialize_async() + + # If initializer names are provided, run them to populate the target registry + if initializer_names or context._initializer_configs: + configs = context._initializer_configs + if configs: + initializer_instances = [] + for config in configs: + initializer_class = context.initializer_registry.get_class(config.name) + instance = initializer_class() + if config.args: + instance.set_params_from_args(args=config.args) + initializer_instances.append(instance) + + await initialize_pyrit_async( + memory_db_type=context._database, + initialization_scripts=context._initialization_scripts, + initializers=initializer_instances, + env_files=context._env_files, + silent=getattr(context, "_silent_reinit", False), + ) + + target_registry = TargetRegistry.get_registry_singleton() + return target_registry.get_names() + + async def run_scenario_async( *, scenario_name: str, context: FrontendCore, + target_name: str | None = None, scenario_strategies: Optional[list[str]] = None, max_concurrency: Optional[int] = None, max_retries: Optional[int] = None, @@ -271,6 +320,9 @@ async def run_scenario_async( Args: scenario_name: Name of the scenario to run. context: PyRIT context with loaded registries. + target_name: Name of a registered target from the TargetRegistry to use as the + objective target. Targets are registered by initializers (e.g., the 'targets' + initializer). Use --list-targets to see available names after initializers run. scenario_strategies: Optional list of strategy names. max_concurrency: Max concurrent operations. max_retries: Max retry attempts. @@ -287,7 +339,7 @@ async def run_scenario_async( ScenarioResult: The result of the scenario execution. Raises: - ValueError: If scenario not found or fails to run. + ValueError: If scenario not found, target not found, or fails to run. Note: Initializers from PyRITContext will be run before the scenario executes. @@ -319,8 +371,27 @@ async def run_scenario_async( initialization_scripts=context._initialization_scripts, initializers=initializer_instances, env_files=context._env_files, + silent=getattr(context, "_silent_reinit", False), ) + # Resolve objective target from TargetRegistry + if target_name is not None: + target_registry = TargetRegistry.get_registry_singleton() + objective_target = target_registry.get_instance_by_name(target_name) + if objective_target is None: + available_names = target_registry.get_names() + if not available_names: + raise ValueError( + f"Target '{target_name}' not found. The target registry is empty.\n" + "Targets are registered by initializers. Make sure to include an initializer " + "that registers targets (e.g., --initializers targets)." + ) + raise ValueError( + f"Target '{target_name}' not found in registry.\nAvailable targets: {', '.join(available_names)}" + ) + else: + objective_target = None + # Get scenario class scenario_class = context.scenario_registry.get_class(scenario_name) @@ -331,6 +402,9 @@ async def run_scenario_async( # Build initialization kwargs (these go to initialize_async, not __init__) init_kwargs: dict[str, Any] = {} + if objective_target is not None: + init_kwargs["objective_target"] = objective_target + if scenario_strategies: strategy_class = scenario_class.get_strategy_class() strategy_enums = [] @@ -580,3 +654,51 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path print("\n" + "=" * 80) print(f"\nTotal initializers: {len(initializers)}") return 0 + + +async def print_targets_list_async(*, context: FrontendCore) -> int: + """ + Print a formatted list of all available targets from the TargetRegistry. + + Targets are registered by initializers, so this requires initializers to run first. + If no targets are found, prints a hint about using the 'targets' initializer. + + Args: + context: PyRIT context with loaded registries. + + Returns: + Exit code (0 for success). + """ + target_names = await list_targets_async(context=context) + + if not target_names: + print("\nNo targets found in registry.") + print( + "\nTargets are registered by initializers. Include an initializer that registers " + "targets, for example:\n --initializers targets\n" + ) + return 0 + + target_registry = TargetRegistry.get_registry_singleton() + + print("\nRegistered Targets:") + print("=" * 80) + for name in target_names: + target = target_registry.get_instance_by_name(name) + if target is None: + print(f" {name}") + continue + + model = target._underlying_model or target._model_name or "" + endpoint = target._endpoint or "" + class_name = type(target).__name__ + + _print_header(text=name) + print(f" Class: {class_name}") + if model: + print(f" Model: {model}") + if endpoint: + print(f" Endpoint: {endpoint}") + print("\n" + "=" * 80) + print(f"\nTotal targets: {len(target_names)}") + return 0 diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index e85fa56e4..cba8c1e89 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -29,23 +29,23 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: description="""PyRIT Scanner - Run security scenarios against AI systems Examples: - # List available scenarios and initializers + # List available scenarios, initializers, and targets pyrit_scan --list-scenarios pyrit_scan --list-initializers + pyrit_scan --list-targets --initializers targets - # Run a scenario with built-in initializers - pyrit_scan foundry --initializers openai_objective_target load_default_datasets + # Run a scenario with a target and initializers + pyrit_scan foundry --target my_target --initializers targets load_default_datasets # Run with a configuration file (recommended for complex setups) - pyrit_scan foundry --config-file ./my_config.yaml + pyrit_scan foundry --target my_target --config-file ./my_config.yaml # Run with custom initialization scripts - pyrit_scan garak.encoding --initialization-scripts ./my_config.py + pyrit_scan garak.encoding --target my_target --initialization-scripts ./my_config.py # Run specific strategies or options - pyrit_scan foundry --strategies base64 rot13 --initializers openai_objective_target - pyrit_scan foundry --initializers openai_objective_target --max-concurrency 10 --max-retries 3 - pyrit_scan garak.encoding --initializers openai_objective_target --memory-labels '{"run_id":"test123"}' + pyrit_scan foundry --target my_target --strategies base64 rot13 --initializers targets + pyrit_scan foundry --target my_target --initializers targets --max-concurrency 10 --max-retries 3 """, formatter_class=RawDescriptionHelpFormatter, ) @@ -75,6 +75,13 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: help="List all available scenario initializers and exit", ) + parser.add_argument( + "--list-targets", + action="store_true", + help="List all available targets from the TargetRegistry and exit. " + "Requires initializers that register targets (e.g., --initializers targets)", + ) + parser.add_argument( "scenario_name", type=str, @@ -82,17 +89,6 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: help="Name of the scenario to run", ) - parser.add_argument( - "--database", - type=frontend_core.validate_database_argparse, - default=None, - help=( - f"Database type to use for memory storage ({frontend_core.IN_MEMORY}, " - f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}). " - f"Defaults to value from config file, or {frontend_core.SQLITE} if not specified." - ), - ) - parser.add_argument( "--initializers", type=frontend_core._parse_initializer_arg, @@ -107,13 +103,6 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: help=frontend_core.ARG_HELP["initialization_scripts"], ) - parser.add_argument( - "--env-files", - type=str, - nargs="+", - help=frontend_core.ARG_HELP["env_files"], - ) - parser.add_argument( "--strategies", "-s", @@ -154,6 +143,12 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: help=frontend_core.ARG_HELP["max_dataset_size"], ) + parser.add_argument( + "--target", + type=str, + help=frontend_core.ARG_HELP["target"], + ) + return parser.parse_args(args) @@ -185,19 +180,9 @@ def main(args: Optional[list[str]] = None) -> int: print(f"Error: {e}") return 1 - env_files = None - if parsed_args.env_files: - try: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) - except ValueError as e: - print(f"Error: {e}") - return 1 - context = frontend_core.FrontendCore( config_file=parsed_args.config_file, - database=parsed_args.database, initialization_scripts=initialization_scripts, - env_files=env_files, log_level=parsed_args.log_level, ) @@ -213,6 +198,15 @@ def main(args: Optional[list[str]] = None) -> int: ) return asyncio.run(frontend_core.print_initializers_list_async(context=context, discovery_path=scenarios_path)) + if parsed_args.list_targets: + # Need initializers to populate target registry + context = frontend_core.FrontendCore( + config_file=parsed_args.config_file, + initializer_names=parsed_args.initializers, + log_level=parsed_args.log_level, + ) + return asyncio.run(frontend_core.print_targets_list_async(context=context)) + # Verify scenario was provided if not parsed_args.scenario_name: print("Error: No scenario specified. Use --help for usage information.") @@ -227,18 +221,11 @@ def main(args: Optional[list[str]] = None) -> int: script_paths=parsed_args.initialization_scripts ) - # Collect environment files - env_files = None - if parsed_args.env_files: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) - # Create context with initializers context = frontend_core.FrontendCore( config_file=parsed_args.config_file, - database=parsed_args.database, initialization_scripts=initialization_scripts, initializer_names=parsed_args.initializers, - env_files=env_files, log_level=parsed_args.log_level, ) @@ -252,6 +239,7 @@ def main(args: Optional[list[str]] = None) -> int: frontend_core.run_scenario_async( scenario_name=parsed_args.scenario_name, context=context, + target_name=parsed_args.target, scenario_strategies=parsed_args.scenario_strategies, max_concurrency=parsed_args.max_concurrency, max_retries=parsed_args.max_retries, diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index f6020d319..26217facb 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: - from collections.abc import Sequence - from pyrit.cli import frontend_core from pyrit.models.scenario_result import ScenarioResult @@ -35,6 +33,7 @@ class PyRITShell(cmd.Cmd): Commands: list-scenarios - List all available scenarios list-initializers - List all available initializers + list-targets - List all available targets from the registry run [opts] - Run a scenario with optional parameters scenario-history - List all previous scenario runs print-scenario [N] - Print detailed results for scenario run(s) @@ -43,20 +42,18 @@ class PyRITShell(cmd.Cmd): exit (quit, q) - Exit the shell Shell Startup Options: - --database Database type (InMemory, SQLite, AzureSQL) - default for all runs + --config-file Path to config file (default: ~/.pyrit/.pyrit_conf) --log-level Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) - default for all runs - --env-files ... Environment files to load in order - default for all runs --no-animation Disable the animated startup banner Run Command Options: + --target Target name from the TargetRegistry (required) --initializers ... Built-in initializers (supports name:key=val1,val2 syntax) --initialization-scripts <...> Custom Python scripts to run before the scenario - --env-files ... Environment files to load in order (overrides startup default) --strategies, -s ... Strategy names to use --max-concurrency Maximum concurrent operations --max-retries Maximum retry attempts --memory-labels JSON string of labels - --database Override default database for this run --log-level Override default log level for this run """ @@ -131,9 +128,7 @@ def __init__( # Set by the background thread after importing frontend_core. self.context: Optional[frontend_core.FrontendCore] = None - self.default_database: Optional[str] = None self.default_log_level: Optional[int] = None - self.default_env_files: Optional[Sequence[Path]] = None # Initialize PyRIT in background thread for faster startup. self._init_thread = threading.Thread(target=self._background_init, daemon=True) @@ -151,9 +146,7 @@ def _background_init(self) -> None: self.context = self._deprecated_context else: self.context = fc.FrontendCore(**self._context_kwargs) - self.default_database = self.context._database self.default_log_level = self.context._log_level - self.default_env_files = self.context._env_files asyncio.run(self.context.initialize_async()) except BaseException as exc: self._init_error = exc @@ -212,6 +205,14 @@ def do_list_initializers(self, arg: str) -> None: except Exception as e: print(f"Error listing initializers: {e}") + def do_list_targets(self, arg: str) -> None: + """List all available targets from the TargetRegistry.""" + self._ensure_initialized() + try: + asyncio.run(self._fc.print_targets_list_async(context=self.context)) + except Exception as e: + print(f"Error listing targets: {e}") + def do_run(self, line: str) -> None: """ Run a scenario. @@ -220,47 +221,47 @@ def do_run(self, line: str) -> None: run [options] Options: + --target Target name from the TargetRegistry (required) --initializers ... Built-in initializers (supports name:key=val1,val2 syntax) --initialization-scripts <...> Custom Python scripts to run before the scenario - --env-files ... Environment files to load in order --strategies, -s ... Strategy names to use --max-concurrency Maximum concurrent operations --max-retries Maximum retry attempts --memory-labels JSON string of labels (e.g., '{"key":"value"}') - --database Override default database (InMemory, SQLite, AzureSQL) --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) Examples: - run garak.encoding --initializers openai_objective_target \ + run garak.encoding --target my_target --initializers targets \ load_default_datasets - run garak.encoding --initializers custom_target \ + run garak.encoding --target my_target --initializers targets \ load_default_datasets --strategies base64 rot13 - run foundry --initializers target:tags=default,scorer \ + run foundry --target my_target --initializers targets:tags=default,scorer \ dataset:mode=strict --strategies base64 - run foundry --initializers openai_objective_target \ + run foundry --target my_target --initializers targets \ load_default_datasets --max-concurrency 10 --max-retries 3 - run garak.encoding --initializers custom_target \ + run garak.encoding --target my_target --initializers targets \ load_default_datasets \ --memory-labels '{"run_id":"test123","env":"dev"}' - run foundry --initializers openai_objective_target \ + run foundry --target my_target --initializers targets \ load_default_datasets -s jailbreak crescendo - run garak.encoding --initializers openai_objective_target \ - load_default_datasets --database InMemory --log-level DEBUG - run foundry --initialization-scripts ./my_custom_init.py -s all + run garak.encoding --target my_target --initializers targets \ + load_default_datasets --log-level DEBUG + run foundry --target my_target --initialization-scripts ./my_custom_init.py -s all Note: - Every scenario requires an initializer (--initializers or --initialization-scripts). - Database and log-level defaults are set at shell startup but can be overridden per-run. - Initializers are specified per-run to allow different setups for different scenarios. + --target is required for every run. + Initializers can be specified per-run or configured in .pyrit_conf. + Database and env-files are configured via the config file. """ self._ensure_initialized() if not line.strip(): print("Error: Specify a scenario name") print("\nUsage: run [options]") - print("\nNote: Every scenario requires an initializer.") + print("\nNote: --target is required. Initializers can be specified per-run or in .pyrit_conf.") print("\nOptions:") - print(f" --initializers ... {self._fc.ARG_HELP['initializers']} (REQUIRED)") + print(f" --target {self._fc.ARG_HELP['target']}") + print(f" --initializers ... {self._fc.ARG_HELP['initializers']}") print( f" --initialization-scripts <...> {self._fc.ARG_HELP['initialization_scripts']}" " (alternative to --initializers)" @@ -269,15 +270,11 @@ def do_run(self, line: str) -> None: print(f" --max-concurrency {self._fc.ARG_HELP['max_concurrency']}") print(f" --max-retries {self._fc.ARG_HELP['max_retries']}") print(f" --memory-labels {self._fc.ARG_HELP['memory_labels']}") - print( - f" --database Override default database" - f" ({self._fc.IN_MEMORY}, {self._fc.SQLITE}, {self._fc.AZURE_SQL})" - ) print( " --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" ) print("\nExample:") - print(" run foundry --initializers openai_objective_target load_default_datasets") + print(" run foundry --target my_target --initializers targets load_default_datasets") print("\nType 'help run' for more details and examples") return @@ -297,24 +294,10 @@ def do_run(self, line: str) -> None: print(f"Error: {e}") return - # Resolve env files if provided - resolved_env_files: Optional[list[Path]] = None - if args["env_files"]: - try: - resolved_env_files = list(self._fc.resolve_env_files(env_file_paths=args["env_files"])) - except ValueError as e: - print(f"Error: {e}") - return - else: - # Use default env files from shell startup - resolved_env_files = list(self.default_env_files) if self.default_env_files else None - # Create a context for this run with overrides run_context = self._fc.FrontendCore( - database=args["database"] or self.default_database, initialization_scripts=resolved_scripts, initializer_names=args["initializers"], - env_files=resolved_env_files, log_level=args["log_level"] if args["log_level"] else self.default_log_level, ) # Use the existing registries (don't reinitialize) @@ -327,6 +310,7 @@ def do_run(self, line: str) -> None: self._fc.run_scenario_async( scenario_name=args["scenario_name"], context=run_context, + target_name=args["target"], scenario_strategies=args["scenario_strategies"], max_concurrency=args["max_concurrency"], max_retries=args["max_retries"], @@ -424,17 +408,16 @@ def do_print_scenario(self, arg: str) -> None: def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" if not arg: - from pyrit.cli._cli_args import ARG_HELP, AZURE_SQL, IN_MEMORY, SQLITE + from pyrit.cli._cli_args import ARG_HELP # Show general help (no full init needed — ARG_HELP is lightweight) super().do_help(arg) print("\n" + "=" * 70) print("Shell Startup Options:") print("=" * 70) - print(" --database ") - print(" Default database type: InMemory, SQLite, or AzureSQL") - print(" Default: SQLite") - print(" Can be overridden per-run with 'run --database '") + print(" --config-file ") + print(" Path to YAML configuration file") + print(" Default: ~/.pyrit/.pyrit_conf") print() print(" --log-level ") print(" Default logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL") @@ -444,13 +427,17 @@ def do_help(self, arg: str) -> None: print("=" * 70) print("Run Command Options (specified when running scenarios):") print("=" * 70) - print(" --initializers [ ...] (REQUIRED)") + print(" --target (REQUIRED)") + print(f" {ARG_HELP['target']}") + print(" Example: run foundry --target my_target --initializers targets load_default_datasets") + print() + print(" --initializers [ ...]") print(f" {ARG_HELP['initializers']}") - print(" Every scenario requires at least one initializer") - print(" Example: run foundry --initializers openai_objective_target load_default_datasets") - print(" With params: run foundry --initializers target:tags=default,scorer") + print(" Example: run foundry --target my_target --initializers targets load_default_datasets") + print(" With params: run foundry --target my_target --initializers targets:tags=default,scorer") print( - " Multiple with params: run foundry --initializers target:tags=default,scorer dataset:mode=strict" + " Multiple with params: run foundry --target my_target" + " --initializers targets:tags=default,scorer dataset:mode=strict" ) print() print(" --initialization-scripts [ ...] (Alternative to --initializers)") @@ -471,12 +458,13 @@ def do_help(self, arg: str) -> None: print(f" {ARG_HELP['memory_labels']}") print(' Example: run foundry --memory-labels \'{"env":"test"}\'') print() - print(f" --database Override ({IN_MEMORY}, {SQLITE}, {AZURE_SQL})") print(" --log-level Override (DEBUG, INFO, WARNING, ERROR, CRITICAL)") print() + print(" Database and env-files are configured via the config file (--config-file).") + print() print("Start the shell like:") print(" pyrit_shell") - print(" pyrit_shell --database InMemory --log-level DEBUG") + print(" pyrit_shell --config-file ./my_config.yaml --log-level DEBUG") else: # Show help for specific command super().do_help(arg) @@ -538,7 +526,7 @@ def main() -> int: """ import argparse - from pyrit.cli._cli_args import ARG_HELP, AZURE_SQL, IN_MEMORY, SQLITE, validate_log_level + from pyrit.cli._cli_args import ARG_HELP, validate_log_level parser = argparse.ArgumentParser( prog="pyrit_shell", @@ -551,16 +539,6 @@ def main() -> int: help=ARG_HELP["config_file"], ) - parser.add_argument( - "--database", - choices=[IN_MEMORY, SQLITE, AZURE_SQL], - default=None, - help=( - f"Default database type to use ({IN_MEMORY}, {SQLITE}, {AZURE_SQL})" - f" (defaults to config file value, or {SQLITE} if not specified)" - ), - ) - parser.add_argument( "--log-level", type=str, @@ -572,13 +550,6 @@ def main() -> int: ), ) - parser.add_argument( - "--env-files", - type=str, - nargs="+", - help=ARG_HELP["env_files"], - ) - parser.add_argument( "--no-animation", action="store_true", @@ -588,17 +559,6 @@ def main() -> int: args = parser.parse_args() - # Resolve and validate env file paths (lightweight — no heavy imports needed). - env_files: Optional[list[Path]] = None - if args.env_files: - from pyrit.cli._cli_args import resolve_env_files - - try: - env_files = resolve_env_files(env_file_paths=args.env_files) - except ValueError as e: - print(f"Error: {e}") - return 1 - # Play the banner immediately, before heavy imports. # Suppress logging so background-thread output doesn't corrupt the animation. prev_disable = logging.root.manager.disable @@ -615,10 +575,6 @@ def main() -> int: shell = PyRITShell( no_animation=args.no_animation, config_file=args.config_file, - database=args.database, - initialization_scripts=None, - initializer_names=None, - env_files=env_files, log_level=validate_log_level(log_level=args.log_level), ) shell.cmdloop(intro=intro) diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 069208e3c..ecaf5e1c3 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -220,7 +220,7 @@ def get_default_capabilities(cls, underlying_model: Optional[str]) -> TargetCapa known = TargetCapabilities.get_known_capabilities(underlying_model) if known is not None: return known - logger.warning( + logger.info( "No known capabilities for model '%s'. Falling back to %s._DEFAULT_CAPABILITIES.", underlying_model, cls.__name__, diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 845bc2d2e..5e2ce65b9 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -231,6 +231,8 @@ def load_with_overrides( default_config_path = DEFAULT_CONFIG_PATH if default_config_path.exists(): try: + logger.info(f"Loading default configuration file: {default_config_path}") + print(f"Loading default configuration file: {default_config_path}") default_config = ConfigurationLoader.from_yaml_file(default_config_path) config_data["memory_db_type"] = default_config.memory_db_type config_data["initializers"] = [ @@ -252,6 +254,8 @@ def load_with_overrides( if config_file is not None: if not config_file.exists(): raise FileNotFoundError(f"Configuration file not found: {config_file}") + logger.info(f"Loading configuration file: {config_file}") + print(f"Loading configuration file: {config_file}") explicit_config = ConfigurationLoader.from_yaml_file(config_file) config_data["memory_db_type"] = explicit_config.memory_db_type config_data["initializers"] = [ diff --git a/pyrit/setup/initializers/__init__.py b/pyrit/setup/initializers/__init__.py index d27fc41c2..f5331f502 100644 --- a/pyrit/setup/initializers/__init__.py +++ b/pyrit/setup/initializers/__init__.py @@ -9,7 +9,6 @@ from pyrit.setup.initializers.pyrit_initializer import InitializerParameter, PyRITInitializer from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets from pyrit.setup.initializers.scenarios.objective_list import ScenarioObjectiveListInitializer -from pyrit.setup.initializers.scenarios.openai_objective_target import ScenarioObjectiveTargetInitializer from pyrit.setup.initializers.simple import SimpleInitializer __all__ = [ @@ -21,5 +20,4 @@ "SimpleInitializer", "LoadDefaultDatasets", "ScenarioObjectiveListInitializer", - "ScenarioObjectiveTargetInitializer", ] diff --git a/pyrit/setup/initializers/scenarios/openai_objective_target.py b/pyrit/setup/initializers/scenarios/openai_objective_target.py deleted file mode 100644 index d27fc60c0..000000000 --- a/pyrit/setup/initializers/scenarios/openai_objective_target.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -OpenAI Objective Target Scenario Configuration. - -This simply sets the target to use OpenAIChatTarget with basic settings. - -It will likely need to be modified based on the target you are testing. But this will work -with OpenAI targets if you set OPENAI_CLI_ENDPOINT -""" - -import os - -from pyrit.common.apply_defaults import set_default_value -from pyrit.prompt_target import OpenAIChatTarget -from pyrit.scenario import Scenario -from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer - - -class ScenarioObjectiveTargetInitializer(PyRITInitializer): - """Configure a simple objective target for use in PyRIT scenarios.""" - - @property - def name(self) -> str: - """Return the display name of this initializer.""" - return "Simple Objective Target Configuration for Scenarios" - - @property - def execution_order(self) -> int: - """Should be executed after most initializers.""" - return 10 - - @property - def description(self) -> str: - """Describe the objective target configuration of this initializer.""" - return ( - "This configuration sets up a simple objective target for scenarios " - "using OpenAIChatTarget with basic settings. It initializes an openAI chat target " - "using the OPENAI_CLI_ENDPOINT and OPENAI_CLI_KEY environment variables." - ) - - @property - def required_env_vars(self) -> list[str]: - """Get list of required environment variables.""" - return [ - "DEFAULT_OPENAI_FRONTEND_ENDPOINT", - ] - - async def initialize_async(self) -> None: - """Set default objective target for scenarios that accept them.""" - objective_target = OpenAIChatTarget( - endpoint=os.getenv("DEFAULT_OPENAI_FRONTEND_ENDPOINT"), - api_key=os.getenv("DEFAULT_OPENAI_FRONTEND_KEY"), - model_name=os.getenv("DEFAULT_OPENAI_FRONTEND_MODEL"), - ) - - set_default_value( - class_type=Scenario, - parameter_name="objective_target", - value=objective_target, - ) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 7c040deb5..8275fd87a 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -582,27 +582,27 @@ def test_parse_run_arguments_with_initializers(self): def test_parse_run_arguments_with_initializer_params(self): """Test parsing initializers with key=value params.""" result = frontend_core.parse_run_arguments( - args_string="test_scenario --initializers simple target:tags=default" + args_string="test_scenario --initializers simple targets:tags=default" ) assert result["initializers"][0] == "simple" - assert result["initializers"][1] == {"name": "target", "args": {"tags": ["default"]}} + assert result["initializers"][1] == {"name": "targets", "args": {"tags": ["default"]}} def test_parse_run_arguments_with_initializer_multiple_params(self): """Test parsing initializers with multiple key=value params separated by semicolons.""" result = frontend_core.parse_run_arguments( - args_string="test_scenario --initializers target:tags=default;mode=strict" + args_string="test_scenario --initializers targets:tags=default;mode=strict" ) - assert result["initializers"][0] == {"name": "target", "args": {"tags": ["default"], "mode": ["strict"]}} + assert result["initializers"][0] == {"name": "targets", "args": {"tags": ["default"], "mode": ["strict"]}} def test_parse_run_arguments_with_initializer_comma_list(self): """Test parsing initializer params with comma-separated values into lists.""" result = frontend_core.parse_run_arguments( - args_string="test_scenario --initializers target:tags=default,scorer" + args_string="test_scenario --initializers targets:tags=default,scorer" ) - assert result["initializers"][0] == {"name": "target", "args": {"tags": ["default", "scorer"]}} + assert result["initializers"][0] == {"name": "targets", "args": {"tags": ["default", "scorer"]}} def test_parse_run_arguments_with_strategies(self): """Test parsing with strategies.""" @@ -634,12 +634,6 @@ def test_parse_run_arguments_with_memory_labels(self): assert result["memory_labels"] == {"key": "value"} - def test_parse_run_arguments_with_database(self): - """Test parsing with database override.""" - result = frontend_core.parse_run_arguments(args_string=f"test_scenario --database {frontend_core.IN_MEMORY}") - - assert result["database"] == frontend_core.IN_MEMORY - def test_parse_run_arguments_with_log_level(self): """Test parsing with log-level override.""" result = frontend_core.parse_run_arguments(args_string="test_scenario --log-level DEBUG") @@ -922,9 +916,235 @@ def test_arg_help_contains_all_keys(self): "memory_labels", "database", "log_level", + "target", ] for key in expected_keys: assert key in frontend_core.ARG_HELP assert isinstance(frontend_core.ARG_HELP[key], str) assert len(frontend_core.ARG_HELP[key]) > 0 + + +class TestParseRunArgumentsTarget: + """Tests for --target parsing in parse_run_arguments.""" + + def test_parse_run_arguments_with_target(self): + """Test parsing with --target.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario --target my_target") + + assert result["target"] == "my_target" + + def test_parse_run_arguments_target_with_other_args(self): + """Test parsing --target alongside other arguments.""" + result = frontend_core.parse_run_arguments( + args_string="test_scenario --target my_target --initializers init1 --max-concurrency 5" + ) + + assert result["target"] == "my_target" + assert result["initializers"] == ["init1"] + assert result["max_concurrency"] == 5 + + def test_parse_run_arguments_target_missing_value(self): + """Test parsing --target without a value raises ValueError.""" + with pytest.raises(ValueError, match="--target requires a value"): + frontend_core.parse_run_arguments(args_string="test_scenario --target") + + def test_parse_run_arguments_no_target(self): + """Test parsing without --target returns None.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario") + + assert result["target"] is None + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +class TestRunScenarioAsyncTarget: + """Tests for target resolution in run_scenario_async.""" + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") + async def test_run_scenario_async_with_valid_target( + self, + mock_printer_class: MagicMock, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + ): + """Test running scenario with a valid target name resolves from registry.""" + # Setup mocks + mock_target = MagicMock() + mock_registry = MagicMock() + mock_registry.get_instance_by_name.return_value = mock_target + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + mock_scenario_registry = MagicMock() + mock_scenario_class = MagicMock() + mock_scenario_instance = MagicMock() + mock_result = MagicMock() + mock_printer = MagicMock() + mock_printer.print_summary_async = AsyncMock() + + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) + mock_scenario_class.return_value = mock_scenario_instance + mock_scenario_registry.get_class.return_value = mock_scenario_class + mock_printer_class.return_value = mock_printer + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = MagicMock() + context._initialized = True + + result = await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + target_name="my_target", + ) + + assert result == mock_result + mock_registry.get_instance_by_name.assert_called_once_with("my_target") + # Verify objective_target was passed to initialize_async + call_kwargs = mock_scenario_instance.initialize_async.call_args[1] + assert call_kwargs["objective_target"] is mock_target + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_run_scenario_async_with_invalid_target( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + ): + """Test running scenario with an invalid target name raises ValueError.""" + mock_registry = MagicMock() + mock_registry.get_instance_by_name.return_value = None + mock_registry.get_names.return_value = ["target_a", "target_b"] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + + with pytest.raises(ValueError, match="Target 'bad_target' not found in registry"): + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + target_name="bad_target", + ) + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_run_scenario_async_with_empty_target_registry( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + ): + """Test running scenario with target name when registry is empty gives helpful error.""" + mock_registry = MagicMock() + mock_registry.get_instance_by_name.return_value = None + mock_registry.get_names.return_value = [] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + + with pytest.raises(ValueError, match="target registry is empty"): + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + target_name="my_target", + ) + + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") + async def test_run_scenario_async_without_target( + self, + mock_printer_class: MagicMock, + mock_init: AsyncMock, + ): + """Test running scenario without target_name does not add objective_target to kwargs.""" + context = frontend_core.FrontendCore() + mock_scenario_registry = MagicMock() + mock_scenario_class = MagicMock() + mock_scenario_instance = MagicMock() + mock_result = MagicMock() + mock_printer = MagicMock() + mock_printer.print_summary_async = AsyncMock() + + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) + mock_scenario_class.return_value = mock_scenario_instance + mock_scenario_registry.get_class.return_value = mock_scenario_class + mock_printer_class.return_value = mock_printer + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = MagicMock() + context._initialized = True + + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + ) + + # Verify no objective_target was passed + call_kwargs = mock_scenario_instance.initialize_async.call_args[1] + assert "objective_target" not in call_kwargs + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +class TestPrintTargetsList: + """Tests for print_targets_list_async function.""" + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_print_targets_list_with_targets( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + capsys, + ): + """Test print_targets_list_async displays target names.""" + mock_registry = MagicMock() + mock_registry.get_names.return_value = ["target_a", "target_b"] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + + result = await frontend_core.print_targets_list_async(context=context) + + assert result == 0 + captured = capsys.readouterr() + assert "target_a" in captured.out + assert "target_b" in captured.out + assert "Total targets: 2" in captured.out + + @patch("pyrit.cli.frontend_core.TargetRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) + async def test_print_targets_list_empty( + self, + mock_init: AsyncMock, + mock_target_registry_class: MagicMock, + capsys, + ): + """Test print_targets_list_async with no targets gives helpful hint.""" + mock_registry = MagicMock() + mock_registry.get_names.return_value = [] + mock_target_registry_class.get_registry_singleton.return_value = mock_registry + + context = frontend_core.FrontendCore() + context._scenario_registry = MagicMock() + context._initializer_registry = MagicMock() + context._initialized = True + + result = await frontend_core.print_targets_list_async(context=context) + + assert result == 0 + captured = capsys.readouterr() + assert "No targets found" in captured.out + assert "--initializers targets" in captured.out diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index a1fdff772..204376a1b 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -36,15 +36,8 @@ def test_parse_args_scenario_name_only(self): args = pyrit_scan.parse_args(["test_scenario"]) assert args.scenario_name == "test_scenario" - assert args.database is None assert args.log_level == logging.WARNING - def test_parse_args_with_database(self): - """Test parsing with database option.""" - args = pyrit_scan.parse_args(["test_scenario", "--database", "InMemory"]) - - assert args.database == "InMemory" - def test_parse_args_with_log_level(self): """Test parsing with log-level option.""" args = pyrit_scan.parse_args(["test_scenario", "--log-level", "DEBUG"]) @@ -98,8 +91,6 @@ def test_parse_args_complex_command(self): args = pyrit_scan.parse_args( [ "encoding_scenario", - "--database", - "InMemory", "--log-level", "INFO", "--initializers", @@ -117,7 +108,6 @@ def test_parse_args_complex_command(self): ) assert args.scenario_name == "encoding_scenario" - assert args.database == "InMemory" assert args.log_level == logging.INFO assert args.initializers == ["openai_target"] assert args.scenario_strategies == ["base64", "rot13"] @@ -125,11 +115,6 @@ def test_parse_args_complex_command(self): assert args.max_retries == 5 assert args.memory_labels == '{"env":"test"}' - def test_parse_args_invalid_database(self): - """Test parsing with invalid database raises error.""" - with pytest.raises(SystemExit): - pyrit_scan.parse_args(["test_scenario", "--database", "InvalidDB"]) - def test_parse_args_invalid_log_level(self): """Test parsing with invalid log level raises error.""" with pytest.raises(SystemExit): @@ -152,6 +137,24 @@ def test_parse_args_help_flag(self): assert exc_info.value.code == 0 + def test_parse_args_with_target(self): + """Test parsing with --target option.""" + args = pyrit_scan.parse_args(["test_scenario", "--target", "my_target"]) + + assert args.target == "my_target" + + def test_parse_args_target_default_is_none(self): + """Test --target defaults to None when not provided.""" + args = pyrit_scan.parse_args(["test_scenario"]) + + assert args.target is None + + def test_parse_args_with_list_targets(self): + """Test parsing --list-targets flag.""" + args = pyrit_scan.parse_args(["--list-targets"]) + + assert args.list_targets is True + class TestMain: """Tests for main function.""" @@ -280,8 +283,6 @@ def test_main_run_scenario_with_all_options( result = pyrit_scan.main( [ "test_scenario", - "--database", - "InMemory", "--log-level", "DEBUG", "--initializers", @@ -304,7 +305,6 @@ def test_main_run_scenario_with_all_options( # Verify FrontendCore was called with correct args call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["database"] == "InMemory" assert call_kwargs["log_level"] == logging.DEBUG assert call_kwargs["initializer_names"] == ["init1", "init2"] @@ -341,14 +341,6 @@ def test_main_run_scenario_with_exception( assert result == 1 - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_database_defaults_to_none(self, mock_frontend_core: MagicMock): - """Test main passes None for database when not specified (config file determines default).""" - pyrit_scan.main(["--list-scenarios"]) - - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["database"] is None - @patch("pyrit.cli.frontend_core.FrontendCore") def test_main_log_level_defaults_to_warning(self, mock_frontend_core: MagicMock): """Test main uses WARNING as default log level.""" diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 58043c550..2dea18a5a 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -54,9 +54,7 @@ def shell(): s._fc = fc_module s.context = mock_context - s.default_database = mock_context._database s.default_log_level = mock_context._log_level - s.default_env_files = mock_context._env_files s._init_complete.set() yield s, mock_context, mock_fc_class @@ -73,7 +71,6 @@ def test_init(self, mock_fc): assert shell._init_complete.is_set() assert shell.context is ctx - assert shell.default_database == "SQLite" assert shell.default_log_level == "WARNING" assert shell._scenario_history == [] mock_fc_class.assert_called_once_with() @@ -214,15 +211,14 @@ def test_do_run_basic_scenario( "scenario_name": "test_scenario", "initializers": ["test_init"], "initialization_scripts": None, - "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, "memory_labels": None, - "database": None, "log_level": None, "dataset_names": None, "max_dataset_size": None, + "target": None, } mock_result = MagicMock() @@ -268,15 +264,14 @@ def test_do_run_with_initialization_scripts( "scenario_name": "test_scenario", "initializers": None, "initialization_scripts": ["script.py"], - "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, "memory_labels": None, - "database": None, "log_level": None, "dataset_names": None, "max_dataset_size": None, + "target": None, } mock_resolve_scripts.return_value = [Path("/test/script.py")] @@ -303,15 +298,14 @@ def test_do_run_with_missing_script( "scenario_name": "test_scenario", "initializers": None, "initialization_scripts": ["missing.py"], - "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, "memory_labels": None, - "database": None, "log_level": None, "dataset_names": None, "max_dataset_size": None, + "target": None, } mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") @@ -321,41 +315,6 @@ def test_do_run_with_missing_script( captured = capsys.readouterr() assert "Error: Script not found" in captured.out - @patch("pyrit.cli.pyrit_shell.asyncio.run") - @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_do_run_with_database_override( - self, - mock_parse_args: MagicMock, - mock_asyncio_run: MagicMock, - shell, - ): - """Test do_run with database override.""" - s, ctx, _ = shell - - mock_parse_args.return_value = { - "scenario_name": "test_scenario", - "initializers": ["test_init"], - "initialization_scripts": None, - "env_files": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "database": "InMemory", - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - } - - mock_asyncio_run.side_effect = [MagicMock()] - - with patch("pyrit.cli.frontend_core.FrontendCore") as mock_frontend: - s.do_run("test_scenario --initializers test_init --database InMemory") - - # Verify FrontendCore was created with overridden database - call_kwargs = mock_frontend.call_args[1] - assert call_kwargs["database"] == "InMemory" - @patch("pyrit.cli.pyrit_shell.asyncio.run") @patch("pyrit.cli.frontend_core.parse_run_arguments") def test_do_run_with_exception( @@ -372,15 +331,14 @@ def test_do_run_with_exception( "scenario_name": "test_scenario", "initializers": ["test_init"], "initialization_scripts": None, - "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, "memory_labels": None, - "database": None, "log_level": None, "dataset_names": None, "max_dataset_size": None, + "target": None, } mock_asyncio_run.side_effect = [ValueError("Test error")] @@ -637,23 +595,22 @@ def test_main_default_args(self, mock_play: MagicMock, mock_shell_class: MagicMo assert result == 0 call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["database"] is None assert call_kwargs["log_level"] == logging.WARNING mock_shell.cmdloop.assert_called_once() @patch("pyrit.cli.pyrit_shell.PyRITShell") @patch("pyrit.cli._banner.play_animation", return_value="") - def test_main_with_database_arg(self, mock_play: MagicMock, mock_shell_class: MagicMock): - """Test main with database argument.""" + def test_main_with_config_file_arg(self, mock_play: MagicMock, mock_shell_class: MagicMock): + """Test main with config-file argument.""" mock_shell = MagicMock() mock_shell_class.return_value = mock_shell - with patch("sys.argv", ["pyrit_shell", "--database", "InMemory"]): + with patch("sys.argv", ["pyrit_shell", "--config-file", "my_config.yaml"]): result = pyrit_shell.main() assert result == 0 call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["database"] == "InMemory" + assert call_kwargs["config_file"] == Path("my_config.yaml") @patch("pyrit.cli.pyrit_shell.PyRITShell") @patch("pyrit.cli._banner.play_animation", return_value="") @@ -710,8 +667,9 @@ def test_main_creates_context_without_initializers(self, mock_play: MagicMock, m pyrit_shell.main() call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["initialization_scripts"] is None - assert call_kwargs["initializer_names"] is None + # main() should not pass initialization_scripts or initializer_names + assert "initialization_scripts" not in call_kwargs + assert "initializer_names" not in call_kwargs @patch("pyrit.cli.pyrit_shell.PyRITShell") @patch("pyrit.cli._banner.play_animation", return_value="") @@ -760,15 +718,14 @@ def test_run_with_all_parameters( "scenario_name": "test_scenario", "initializers": ["init1"], "initialization_scripts": None, - "env_files": None, "scenario_strategies": ["s1", "s2"], "max_concurrency": 10, "max_retries": 5, "memory_labels": {"key": "value"}, - "database": "InMemory", "log_level": "DEBUG", "dataset_names": None, "max_dataset_size": None, + "target": None, } mock_asyncio_run.side_effect = [MagicMock()] @@ -795,15 +752,14 @@ def test_run_stores_result_in_history( "scenario_name": "test_scenario", "initializers": ["test_init"], "initialization_scripts": None, - "env_files": None, "scenario_strategies": None, "max_concurrency": None, "max_retries": None, "memory_labels": None, - "database": None, "log_level": None, "dataset_names": None, "max_dataset_size": None, + "target": None, } mock_result1 = MagicMock() diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index 0ffb2688c..f64d7d320 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -397,8 +397,8 @@ def test_returns_class_default_and_warns_when_model_is_unrecognized(self): cls = self._make_target_class(default_caps=custom_caps) with patch("pyrit.prompt_target.common.prompt_target.logger") as mock_logger: result = cls.get_default_capabilities("totally-unknown-model") - mock_logger.warning.assert_called_once() - warning_args = mock_logger.warning.call_args[0] + mock_logger.info.assert_called_once() + warning_args = mock_logger.info.call_args[0] assert "totally-unknown-model" in warning_args[1] assert result is custom_caps