Skip to content
Merged
1 change: 0 additions & 1 deletion .pyrit_conf_example
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ memory_db_type: sqlite
# - scorers: Registers pre-configured scorers into the ScorerRegistry
# - load_default_datasets: Loads default datasets for all registered scenarios
# - objective_list: Sets default objectives for scenarios
# - openai_objective_target: Sets up OpenAI target for scenarios
#
# Each initializer can be specified as:
# - A simple string (name only)
Expand Down
2 changes: 1 addition & 1 deletion pyrit/cli/_banner.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def add(line: str, role: ColorRole, segments: Optional[list[tuple[int, int, Colo
quick_start = [
"Quick Start:",
" pyrit> list-scenarios",
" pyrit> run foundry --initializers openai_objective_target load_default_datasets",
" pyrit> run foundry --target my_target --initializers targets load_default_datasets",
]
for qs in quick_start:
full_line = _box_line(" " + qs)
Expand Down
37 changes: 10 additions & 27 deletions pyrit/cli/_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def parse_memory_labels(json_string: str) -> dict[str, str]:
"Creates a new dataset config; fetches all items unless --max-dataset-size is also specified",
"max_dataset_size": "Maximum number of items to use from the dataset (must be >= 1). "
"Limits new datasets if --dataset-names provided, otherwise overrides scenario's default limit",
"target": "Name of a registered target from the TargetRegistry to use as the objective target. "
"Targets are registered by initializers (e.g., 'targets' initializer). "
"Use --list-targets to see available target names after initializers have run",
}


Expand Down Expand Up @@ -372,15 +375,14 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
"scenario_name": parts[0],
"initializers": None,
"initialization_scripts": None,
"env_files": None,
"scenario_strategies": None,
"max_concurrency": None,
"max_retries": None,
"memory_labels": None,
"database": None,
"log_level": None,
"dataset_names": None,
"max_dataset_size": None,
"target": None,
}

i = 1
Expand All @@ -399,13 +401,6 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
while i < len(parts) and not parts[i].startswith("--"):
result["initialization_scripts"].append(parts[i])
i += 1
elif parts[i] == "--env-files":
# Collect env file paths until next flag
result["env_files"] = []
i += 1
while i < len(parts) and not parts[i].startswith("--"):
result["env_files"].append(parts[i])
i += 1
elif parts[i] in ("--strategies", "-s"):
# Collect strategies until next flag
result["scenario_strategies"] = []
Expand All @@ -431,12 +426,6 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
raise ValueError("--memory-labels requires a value")
result["memory_labels"] = parse_memory_labels(parts[i])
i += 1
elif parts[i] == "--database":
i += 1
if i >= len(parts):
raise ValueError("--database requires a value")
result["database"] = validate_database(database=parts[i])
i += 1
elif parts[i] == "--log-level":
i += 1
if i >= len(parts):
Expand All @@ -456,6 +445,12 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
raise ValueError("--max-dataset-size requires a value")
result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1)
i += 1
elif parts[i] == "--target":
i += 1
if i >= len(parts):
raise ValueError("--target requires a value")
result["target"] = parts[i]
i += 1
else:
raise ValueError(f"Unknown argument: {parts[i]}")

Expand All @@ -470,24 +465,12 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]:
def add_common_arguments(parser: argparse.ArgumentParser) -> None:
"""Add arguments shared between pyrit_shell and pyrit_scan."""
parser.add_argument("--config-file", type=Path, help=ARG_HELP["config_file"])
parser.add_argument(
"--database",
type=validate_database_argparse,
default=None,
help=f"Database type to use ({IN_MEMORY}, {SQLITE}, {AZURE_SQL}). Defaults to config or {SQLITE}.",
)
parser.add_argument(
"--log-level",
type=validate_log_level_argparse,
default=logging.WARNING,
help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)",
)
parser.add_argument(
"--env-files",
type=str,
nargs="+",
help=ARG_HELP["env_files"],
)


# Module-level logger (stdlib only — no heavy deps)
Expand Down
126 changes: 124 additions & 2 deletions pyrit/cli/frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from pyrit.cli._cli_args import validate_integer as validate_integer
from pyrit.cli._cli_args import validate_log_level as validate_log_level
from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse
from pyrit.registry import InitializerRegistry, ScenarioRegistry
from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry
from pyrit.scenario import DatasetConfiguration
from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter
from pyrit.setup import ConfigurationLoader, initialize_pyrit_async
Expand Down Expand Up @@ -173,6 +173,8 @@ async def initialize_async(self) -> None:
initializers=None,
env_files=self._env_files,
)
# Mark that initial env loading has been printed
self._silent_reinit = True

# Load registries (use singleton pattern for shared access)
self._scenario_registry = ScenarioRegistry.get_registry_singleton()
Expand Down Expand Up @@ -253,10 +255,57 @@ async def list_initializers_async(
return context.initializer_registry.list_metadata()


async def list_targets_async(
*,
context: FrontendCore,
initializer_names: Optional[list[Any]] = None,
) -> list[str]:
"""
List available target names from the TargetRegistry.

Since targets are registered by initializers, this function requires initializers
to have been run first. If initializer_names are provided, they will be resolved
and run before querying the registry.

Args:
context: PyRIT context with loaded registries.
initializer_names: Optional list of initializer entries to run before listing.

Returns:
Sorted list of registered target names.
"""
if not context._initialized:
await context.initialize_async()

# If initializer names are provided, run them to populate the target registry
if initializer_names or context._initializer_configs:
configs = context._initializer_configs
if configs:
initializer_instances = []
for config in configs:
initializer_class = context.initializer_registry.get_class(config.name)
instance = initializer_class()
if config.args:
instance.set_params_from_args(args=config.args)
initializer_instances.append(instance)

await initialize_pyrit_async(
memory_db_type=context._database,
initialization_scripts=context._initialization_scripts,
initializers=initializer_instances,
env_files=context._env_files,
silent=getattr(context, "_silent_reinit", False),
)

target_registry = TargetRegistry.get_registry_singleton()
return target_registry.get_names()


async def run_scenario_async(
*,
scenario_name: str,
context: FrontendCore,
target_name: str | None = None,
scenario_strategies: Optional[list[str]] = None,
max_concurrency: Optional[int] = None,
max_retries: Optional[int] = None,
Expand All @@ -271,6 +320,9 @@ async def run_scenario_async(
Args:
scenario_name: Name of the scenario to run.
context: PyRIT context with loaded registries.
target_name: Name of a registered target from the TargetRegistry to use as the
objective target. Targets are registered by initializers (e.g., the 'targets'
initializer). Use --list-targets to see available names after initializers run.
scenario_strategies: Optional list of strategy names.
max_concurrency: Max concurrent operations.
max_retries: Max retry attempts.
Expand All @@ -287,7 +339,7 @@ async def run_scenario_async(
ScenarioResult: The result of the scenario execution.

Raises:
ValueError: If scenario not found or fails to run.
ValueError: If scenario not found, target not found, or fails to run.

Note:
Initializers from PyRITContext will be run before the scenario executes.
Expand Down Expand Up @@ -319,8 +371,27 @@ async def run_scenario_async(
initialization_scripts=context._initialization_scripts,
initializers=initializer_instances,
env_files=context._env_files,
silent=getattr(context, "_silent_reinit", False),
)

# Resolve objective target from TargetRegistry
if target_name is not None:
target_registry = TargetRegistry.get_registry_singleton()
objective_target = target_registry.get_instance_by_name(target_name)
if objective_target is None:
available_names = target_registry.get_names()
if not available_names:
raise ValueError(
f"Target '{target_name}' not found. The target registry is empty.\n"
"Targets are registered by initializers. Make sure to include an initializer "
"that registers targets (e.g., --initializers targets)."
)
raise ValueError(
f"Target '{target_name}' not found in registry.\nAvailable targets: {', '.join(available_names)}"
)
else:
objective_target = None

# Get scenario class
scenario_class = context.scenario_registry.get_class(scenario_name)

Expand All @@ -331,6 +402,9 @@ async def run_scenario_async(
# Build initialization kwargs (these go to initialize_async, not __init__)
init_kwargs: dict[str, Any] = {}

if objective_target is not None:
init_kwargs["objective_target"] = objective_target

if scenario_strategies:
strategy_class = scenario_class.get_strategy_class()
strategy_enums = []
Expand Down Expand Up @@ -580,3 +654,51 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path
print("\n" + "=" * 80)
print(f"\nTotal initializers: {len(initializers)}")
return 0


async def print_targets_list_async(*, context: FrontendCore) -> int:
"""
Print a formatted list of all available targets from the TargetRegistry.

Targets are registered by initializers, so this requires initializers to run first.
If no targets are found, prints a hint about using the 'targets' initializer.

Args:
context: PyRIT context with loaded registries.

Returns:
Exit code (0 for success).
"""
target_names = await list_targets_async(context=context)

if not target_names:
print("\nNo targets found in registry.")
print(
"\nTargets are registered by initializers. Include an initializer that registers "
"targets, for example:\n --initializers targets\n"
)
return 0

target_registry = TargetRegistry.get_registry_singleton()

print("\nRegistered Targets:")
print("=" * 80)
for name in target_names:
target = target_registry.get_instance_by_name(name)
if target is None:
print(f" {name}")
continue

model = target._underlying_model or target._model_name or ""
endpoint = target._endpoint or ""
class_name = type(target).__name__

_print_header(text=name)
print(f" Class: {class_name}")
if model:
print(f" Model: {model}")
if endpoint:
print(f" Endpoint: {endpoint}")
print("\n" + "=" * 80)
print(f"\nTotal targets: {len(target_names)}")
return 0
Loading
Loading