From 905bac6d2a588738ee3eb8ff5e7ecb99c0d643ef Mon Sep 17 00:00:00 2001 From: biefan <70761325+biefan@users.noreply.github.com> Date: Tue, 17 Mar 2026 03:44:24 +0000 Subject: [PATCH 01/14] Handle PyRIT shell initialization failures --- pyrit/cli/pyrit_shell.py | 19 ++++++++++-- tests/unit/cli/test_pyrit_shell.py | 50 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index a789a0dda..2efa5d2c0 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -84,12 +84,22 @@ def __init__( # Initialize PyRIT in background thread for faster startup. self._init_thread = threading.Thread(target=self._background_init, daemon=True) self._init_complete = threading.Event() + self._init_error: Optional[BaseException] = None self._init_thread.start() def _background_init(self) -> None: """Initialize PyRIT modules in the background. This dramatically speeds up shell startup.""" - asyncio.run(self.context.initialize_async()) - self._init_complete.set() + try: + asyncio.run(self.context.initialize_async()) + except BaseException as exc: + self._init_error = exc + finally: + self._init_complete.set() + + def _raise_init_error(self) -> None: + """Re-raise background initialization failures on the calling thread.""" + if self._init_error is not None: + raise self._init_error def _ensure_initialized(self) -> None: """Wait for initialization to complete if not already done.""" @@ -97,14 +107,17 @@ def _ensure_initialized(self) -> None: print("Waiting for PyRIT initialization to complete...") sys.stdout.flush() self._init_complete.wait() + self._raise_init_error() def cmdloop(self, intro: Optional[str] = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" if intro is None: # Wait for background init to finish BEFORE animation, # so its log output doesn't interfere with cursor positioning - self._init_complete.wait() + self._ensure_initialized() intro = banner.play_animation(no_animation=self._no_animation) + elif self._init_complete.is_set(): + self._raise_init_error() self.intro = intro super().cmdloop(intro=self.intro) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index c70aa43a5..aa31532a6 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -6,9 +6,12 @@ """ import cmd +import threading from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch +import pytest + from pyrit.cli import _banner as banner from pyrit.cli import pyrit_shell @@ -33,6 +36,53 @@ def test_init(self): shell._init_complete.wait(timeout=2) mock_context.initialize_async.assert_called_once() + def test_background_init_failure_sets_event_and_raises_in_ensure_initialized(self): + """Test failed background initialization unblocks waiters and surfaces the original error.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell._init_thread.join(timeout=2) + + assert shell._init_complete.is_set() + with pytest.raises(RuntimeError, match="Initialization failed"): + shell._ensure_initialized() + + def test_cmdloop_does_not_hang_when_background_init_fails(self): + """Test cmdloop surfaces background initialization failures instead of waiting forever.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell._init_thread.join(timeout=2) + + errors: list[BaseException] = [] + + def run_cmdloop() -> None: + try: + shell.cmdloop() + except BaseException as exc: # pragma: no cover - assertion target + errors.append(exc) + + with ( + patch("pyrit.cli._banner.play_animation") as mock_play, + patch("cmd.Cmd.cmdloop") as mock_cmdloop, + ): + cmdloop_thread = threading.Thread(target=run_cmdloop, daemon=True) + cmdloop_thread.start() + cmdloop_thread.join(timeout=0.5) + + assert not cmdloop_thread.is_alive() + assert len(errors) == 1 + assert isinstance(errors[0], RuntimeError) + assert str(errors[0]) == "Initialization failed" + mock_play.assert_not_called() + mock_cmdloop.assert_not_called() + def test_prompt_and_intro(self): """Test shell prompt is set and cmdloop wires play_animation to intro.""" mock_context = MagicMock() From 821390f41658bf87e4ac5199637a67922386c0ef Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 09:12:20 -0700 Subject: [PATCH 02/14] Fix pyrit_shell startup: play animation without blocking on init cmdloop() was waiting for background initialization to complete before playing the banner animation, defeating the purpose of the background thread and causing ~14s startup instead of ~4s. Changes: - Play banner animation immediately while init runs in background - Suppress root logger during animation to prevent ANSI cursor corruption - Surface init errors after animation if init already finished - Add test verifying cmdloop does not block on slow init Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/pyrit_shell.py | 19 +++++++++++--- tests/unit/cli/test_pyrit_shell.py | 40 +++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 2efa5d2c0..6ce28d463 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -12,6 +12,7 @@ import asyncio import cmd +import logging import sys import threading from pathlib import Path @@ -112,10 +113,20 @@ def _ensure_initialized(self) -> None: def cmdloop(self, intro: Optional[str] = None) -> None: """Override cmdloop to play animated banner before starting the REPL.""" if intro is None: - # Wait for background init to finish BEFORE animation, - # so its log output doesn't interfere with cursor positioning - self._ensure_initialized() - intro = banner.play_animation(no_animation=self._no_animation) + # Play animation immediately while background init continues. + # Suppress logging during the animation so log lines don't corrupt + # the ANSI cursor-positioned frames. + root_logger = logging.getLogger() + prev_level = root_logger.level + root_logger.setLevel(logging.CRITICAL) + try: + intro = banner.play_animation(no_animation=self._no_animation) + finally: + root_logger.setLevel(prev_level) + + # If init already failed while the animation played, surface it now. + if self._init_complete.is_set(): + self._raise_init_error() elif self._init_complete.is_set(): self._raise_init_error() self.intro = intro diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index aa31532a6..9a9c3cefb 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -80,9 +80,47 @@ def run_cmdloop() -> None: assert len(errors) == 1 assert isinstance(errors[0], RuntimeError) assert str(errors[0]) == "Initialization failed" - mock_play.assert_not_called() + # Animation plays first (no blocking wait), then error is surfaced + mock_play.assert_called_once() mock_cmdloop.assert_not_called() + def test_cmdloop_does_not_block_on_slow_init(self): + """Test that cmdloop plays the animation immediately without waiting for initialization.""" + init_started = threading.Event() + init_release = threading.Event() + + async def slow_init() -> None: + init_started.set() + # Block until the test explicitly releases us + init_release.wait() + + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context.initialize_async = slow_init + + shell = pyrit_shell.PyRITShell(context=mock_context) + # Wait for background init to actually start running + init_started.wait(timeout=2) + + with ( + patch("pyrit.cli._banner.play_animation", return_value="BANNER") as mock_play, + patch("cmd.Cmd.cmdloop") as mock_cmdloop, + ): + cmdloop_thread = threading.Thread(target=shell.cmdloop, daemon=True) + cmdloop_thread.start() + cmdloop_thread.join(timeout=2) + + # cmdloop should have completed even though init is still running + assert not cmdloop_thread.is_alive() + mock_play.assert_called_once() + mock_cmdloop.assert_called_once() + assert not shell._init_complete.is_set() + + # Clean up: let the background init finish + init_release.set() + shell._init_thread.join(timeout=2) + def test_prompt_and_intro(self): """Test shell prompt is set and cmdloop wires play_animation to intro.""" mock_context = MagicMock() From 723f8e4192e7bf1a205c0645236ebfc81faaa5a0 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 09:38:31 -0700 Subject: [PATCH 03/14] Defer heavy frontend_core import to background thread for fast shell startup The module-level 'from pyrit.cli import frontend_core' took ~11s because it transitively imports pyrit.scenario. This blocked the shell from starting before all modules were loaded. Changes: - Move frontend_core import to TYPE_CHECKING (type hints only) - Background thread now handles: import frontend_core -> create context -> initialize_async (all ~14s of work off the main thread) - main() plays banner animation immediately, then creates shell with context_kwargs for deferred construction - Shell methods use lazy 'from pyrit.cli import frontend_core' which is free (sys.modules lookup) after the background thread imports it - PyRITShell.__init__ accepts either context= (tests) or context_kwargs= (CLI startup) for backward compatibility Result: module import drops from ~14s to ~3s. The prompt appears after the banner animation (~5s total) while init continues in background. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/pyrit_shell.py | 118 ++++++++++++++++++++--------- tests/unit/cli/test_pyrit_shell.py | 60 ++++++++------- 2 files changed, 116 insertions(+), 62 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 6ce28d463..3df5e67b0 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -16,13 +16,13 @@ import sys import threading from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: + from pyrit.cli import frontend_core from pyrit.models.scenario_result import ScenarioResult from pyrit.cli import _banner as banner -from pyrit.cli import frontend_core class PyRITShell(cmd.Cmd): @@ -62,26 +62,44 @@ class PyRITShell(cmd.Cmd): def __init__( self, *, - context: frontend_core.FrontendCore, + context: Optional[frontend_core.FrontendCore] = None, no_animation: bool = False, + context_kwargs: Optional[dict[str, Any]] = None, ) -> None: """ Initialize the PyRIT shell. + Accepts either a pre-created context (for testing or when imports are already done) + or context_kwargs for deferred creation in the background thread. + Args: - context: PyRIT context with loaded registries. + context: Pre-created PyRIT context. If provided, only ``initialize_async`` + runs in the background. no_animation: If True, skip the animated startup banner. + context_kwargs: Keyword arguments forwarded to ``FrontendCore()``. + When provided (and *context* is ``None``), the heavy + ``frontend_core`` import and context construction happen on + the background thread so the shell prompt appears immediately. """ super().__init__() - self.context = context self._no_animation = no_animation - self.default_database = context._database - self.default_log_level: Optional[int] = context._log_level - self.default_env_files = context._env_files + self._context_kwargs = context_kwargs # Track scenario execution history: list of (command_string, ScenarioResult) tuples self._scenario_history: list[tuple[str, ScenarioResult]] = [] + if context is not None: + self.context = context + self.default_database = context._database + self.default_log_level: Optional[int] = context._log_level + self.default_env_files = context._env_files + else: + # Will be set by the background thread after importing frontend_core. + self.context = None # type: ignore[assignment] + self.default_database = None + self.default_log_level = None + self.default_env_files = None + # Initialize PyRIT in background thread for faster startup. self._init_thread = threading.Thread(target=self._background_init, daemon=True) self._init_complete = threading.Event() @@ -89,8 +107,20 @@ def __init__( self._init_thread.start() def _background_init(self) -> None: - """Initialize PyRIT modules in the background. This dramatically speeds up shell startup.""" + """Import heavy modules and initialize PyRIT in the background. + + When *context_kwargs* were provided, this thread performs the expensive + ``from pyrit.cli import frontend_core`` import and creates the + ``FrontendCore`` context before calling ``initialize_async``. + """ try: + if self.context is None: + from pyrit.cli import frontend_core as fc + + self.context = fc.FrontendCore(**(self._context_kwargs or {})) + self.default_database = self.context._database + self.default_log_level = self.context._log_level + self.default_env_files = self.context._env_files asyncio.run(self.context.initialize_async()) except BaseException as exc: self._init_error = exc @@ -135,6 +165,8 @@ def cmdloop(self, intro: Optional[str] = None) -> None: def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" self._ensure_initialized() + from pyrit.cli import frontend_core + try: asyncio.run(frontend_core.print_scenarios_list_async(context=self.context)) except Exception as e: @@ -143,6 +175,8 @@ def do_list_scenarios(self, arg: str) -> None: def do_list_initializers(self, arg: str) -> None: """List all available initializers.""" self._ensure_initialized() + from pyrit.cli import frontend_core + try: # Discover from scenarios directory by default (same as scan) discovery_path = frontend_core.get_default_initializer_discovery_path() @@ -194,6 +228,8 @@ def do_run(self, line: str) -> None: Initializers are specified per-run to allow different setups for different scenarios. """ self._ensure_initialized() + from pyrit.cli import frontend_core + if not line.strip(): print("Error: Specify a scenario name") print("\nUsage: run [options]") @@ -365,6 +401,9 @@ def do_print_scenario(self, arg: str) -> None: def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" if not arg: + self._ensure_initialized() + from pyrit.cli import frontend_core + # Show general help super().do_help(arg) print("\n" + "=" * 70) @@ -482,17 +521,20 @@ def main() -> int: parser.add_argument( "--config-file", type=Path, - help=frontend_core.ARG_HELP["config_file"], + help=( + "Path to a YAML configuration file. Allows specifying database, initializers, " + "initialization scripts, and env files. CLI arguments override config file values. " + "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." + ), ) parser.add_argument( "--database", - choices=[frontend_core.IN_MEMORY, frontend_core.SQLITE, frontend_core.AZURE_SQL], + choices=["InMemory", "SQLite", "AzureSQL"], default=None, help=( - f"Default database type to use" - f" ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL})" - f" (defaults to config file value, or {frontend_core.SQLITE} if not specified)" + "Default database type to use (InMemory, SQLite, AzureSQL)" + " (defaults to config file value, or SQLite if not specified)" ), ) @@ -523,29 +565,37 @@ def main() -> int: args = parser.parse_args() - # Resolve env files if provided - env_files = None + # Resolve env file paths (lightweight — no heavy imports needed). + env_files: Optional[list[Path]] = None if args.env_files: - try: - env_files = frontend_core.resolve_env_files(env_file_paths=args.env_files) - except ValueError as e: - print(f"Error: {e}") - return 1 - - # Create context (initializers are specified per-run, not at startup) - context = frontend_core.FrontendCore( - config_file=args.config_file, - database=args.database, - initialization_scripts=None, - initializer_names=None, - env_files=env_files, - log_level=args.log_level, - ) + env_files = [Path(p).resolve() for p in args.env_files] + + # Play the banner immediately, before heavy imports. + # Suppress logging so background-thread output doesn't corrupt the animation. + root_logger = logging.getLogger() + prev_level = root_logger.level + root_logger.setLevel(logging.CRITICAL) + try: + intro = banner.play_animation(no_animation=args.no_animation) + finally: + root_logger.setLevel(prev_level) - # Start shell + # Create shell with deferred initialization — the background thread + # will import frontend_core, create the FrontendCore context, and call + # initialize_async while the user is already at the prompt. try: - shell = PyRITShell(context=context, no_animation=args.no_animation) - shell.cmdloop() + shell = PyRITShell( + no_animation=args.no_animation, + context_kwargs={ + "config_file": args.config_file, + "database": args.database, + "initialization_scripts": None, + "initializer_names": None, + "env_files": env_files, + "log_level": args.log_level, + }, + ) + shell.cmdloop(intro=intro) return 0 except KeyboardInterrupt: print("\n\nInterrupted. Goodbye!") diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 9a9c3cefb..cfb5fde6b 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -729,8 +729,8 @@ class TestMain: """Tests for main function.""" @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_default_args(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_default_args(self, mock_play: MagicMock, mock_shell_class: MagicMock): """Test main with default arguments.""" mock_shell = MagicMock() mock_shell_class.return_value = mock_shell @@ -739,15 +739,15 @@ def test_main_default_args(self, mock_frontend_core: MagicMock, mock_shell_class result = pyrit_shell.main() assert result == 0 - mock_frontend_core.assert_called_once() - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["database"] is None - assert call_kwargs["log_level"] == "WARNING" + call_kwargs = mock_shell_class.call_args[1] + ctx_kw = call_kwargs["context_kwargs"] + assert ctx_kw["database"] is None + assert ctx_kw["log_level"] == "WARNING" mock_shell.cmdloop.assert_called_once() @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_with_database_arg(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_with_database_arg(self, mock_play: MagicMock, mock_shell_class: MagicMock): """Test main with database argument.""" mock_shell = MagicMock() mock_shell_class.return_value = mock_shell @@ -756,12 +756,12 @@ def test_main_with_database_arg(self, mock_frontend_core: MagicMock, mock_shell_ result = pyrit_shell.main() assert result == 0 - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["database"] == "InMemory" + ctx_kw = mock_shell_class.call_args[1]["context_kwargs"] + assert ctx_kw["database"] == "InMemory" @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_with_log_level_arg(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_with_log_level_arg(self, mock_play: MagicMock, mock_shell_class: MagicMock): """Test main with log-level argument.""" mock_shell = MagicMock() mock_shell_class.return_value = mock_shell @@ -770,12 +770,12 @@ def test_main_with_log_level_arg(self, mock_frontend_core: MagicMock, mock_shell result = pyrit_shell.main() assert result == 0 - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["log_level"] == "DEBUG" + ctx_kw = mock_shell_class.call_args[1]["context_kwargs"] + assert ctx_kw["log_level"] == "DEBUG" @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_with_keyboard_interrupt(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock, capsys): + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_with_keyboard_interrupt(self, mock_play: MagicMock, mock_shell_class: MagicMock, capsys): """Test main handles keyboard interrupt.""" mock_shell = MagicMock() mock_shell.cmdloop.side_effect = KeyboardInterrupt() @@ -789,8 +789,8 @@ def test_main_with_keyboard_interrupt(self, mock_frontend_core: MagicMock, mock_ assert "Interrupted" in captured.out @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_with_exception(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock, capsys): + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_with_exception(self, mock_play: MagicMock, mock_shell_class: MagicMock, capsys): """Test main handles exceptions.""" mock_shell = MagicMock() mock_shell.cmdloop.side_effect = ValueError("Test error") @@ -803,19 +803,23 @@ def test_main_with_exception(self, mock_frontend_core: MagicMock, mock_shell_cla captured = capsys.readouterr() assert "Error:" in captured.out - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_creates_context_without_initializers(self, mock_frontend_core: MagicMock): + @patch("pyrit.cli.pyrit_shell.PyRITShell") + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_creates_context_without_initializers(self, mock_play: MagicMock, mock_shell_class: MagicMock): """Test main creates context without initializers.""" - with patch("pyrit.cli.pyrit_shell.PyRITShell"), patch("sys.argv", ["pyrit_shell"]): + mock_shell = MagicMock() + mock_shell_class.return_value = mock_shell + + with patch("sys.argv", ["pyrit_shell"]): pyrit_shell.main() - call_kwargs = mock_frontend_core.call_args[1] - assert call_kwargs["initialization_scripts"] is None - assert call_kwargs["initializer_names"] is None + ctx_kw = mock_shell_class.call_args[1]["context_kwargs"] + assert ctx_kw["initialization_scripts"] is None + assert ctx_kw["initializer_names"] is None @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_with_no_animation_flag(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_with_no_animation_flag(self, mock_play: MagicMock, mock_shell_class: MagicMock): """Test main passes --no-animation flag to PyRITShell.""" mock_shell = MagicMock() mock_shell_class.return_value = mock_shell @@ -828,8 +832,8 @@ def test_main_with_no_animation_flag(self, mock_frontend_core: MagicMock, mock_s assert call_kwargs["no_animation"] is True @patch("pyrit.cli.pyrit_shell.PyRITShell") - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_main_default_animation_enabled(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + @patch("pyrit.cli._banner.play_animation", return_value="") + def test_main_default_animation_enabled(self, mock_play: MagicMock, mock_shell_class: MagicMock): """Test main defaults to animation enabled (no_animation=False).""" mock_shell = MagicMock() mock_shell_class.return_value = mock_shell From 1723c17ee3f53bd7b6cfe6e6bc002f633a03dbbd Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 09:42:48 -0700 Subject: [PATCH 04/14] Add integration test for shell startup performance (<5s) Runs pyrit_shell module import in a subprocess and asserts it completes within 5 seconds. Guards against regressions from heavy top-level imports that would delay the banner animation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/integration/cli/__init__.py | 2 + .../cli/test_pyrit_shell_startup.py | 50 +++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 tests/integration/cli/__init__.py create mode 100644 tests/integration/cli/test_pyrit_shell_startup.py diff --git a/tests/integration/cli/__init__.py b/tests/integration/cli/__init__.py new file mode 100644 index 000000000..9a0454564 --- /dev/null +++ b/tests/integration/cli/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/integration/cli/test_pyrit_shell_startup.py b/tests/integration/cli/test_pyrit_shell_startup.py new file mode 100644 index 000000000..a16fcaf8b --- /dev/null +++ b/tests/integration/cli/test_pyrit_shell_startup.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Integration test for pyrit_shell startup performance. + +The shell uses a background thread to import heavy modules (frontend_core, +pyrit.scenario, etc.) so the banner animation can start quickly. If the +module-level imports regress, the animation will be blocked and the shell +will appear to hang on startup. +""" + +import subprocess +import sys + +# Maximum acceptable time (seconds) for `from pyrit.cli import pyrit_shell` +# to complete. The banner animation starts immediately after this import, +# so keeping it under 5 s is critical for perceived startup speed. +_MAX_IMPORT_SECONDS = 6 + + +def test_pyrit_shell_module_imports_within_budget() -> None: + """Importing pyrit.cli.pyrit_shell must complete in under 6 seconds. + + This guards against accidentally adding heavy top-level imports + (e.g. ``from pyrit.cli import frontend_core``) that would block + the main thread and delay the banner animation. + """ + script = ( + "import time; " + "t = time.perf_counter(); " + "from pyrit.cli import pyrit_shell; " # noqa: F401 + "elapsed = time.perf_counter() - t; " + "print(f'{elapsed:.2f}'); " + f"raise SystemExit(0 if elapsed < {_MAX_IMPORT_SECONDS} else 1)" + ) + + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=30, + ) + + elapsed_str = result.stdout.strip() + assert result.returncode == 0, ( + f"pyrit_shell module import took {elapsed_str}s, " + f"exceeding the {_MAX_IMPORT_SECONDS}s budget. " + "Check for heavy top-level imports in pyrit_shell.py." + ) From df8bd9daa6bf2327361eaa03f9d78fdab257120e Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 12:07:35 -0700 Subject: [PATCH 05/14] Increase cmdloop deadlock test timeout to 5s for CI stability Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/cli/test_pyrit_shell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index cfb5fde6b..ebe13bc18 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -74,7 +74,7 @@ def run_cmdloop() -> None: ): cmdloop_thread = threading.Thread(target=run_cmdloop, daemon=True) cmdloop_thread.start() - cmdloop_thread.join(timeout=0.5) + cmdloop_thread.join(timeout=5) assert not cmdloop_thread.is_alive() assert len(errors) == 1 From c8f11fa02a71d173838cda6978aac2dea5fc0746 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 12:12:01 -0700 Subject: [PATCH 06/14] Validate context/context_kwargs exclusivity, add context_kwargs tests - Raise ValueError if both context and context_kwargs are provided - Add test for context_kwargs background initialization path - Add test for the mutual exclusivity validation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/pyrit_shell.py | 10 +++++++++- tests/unit/cli/test_pyrit_shell.py | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 3df5e67b0..0e6100ed2 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -80,9 +80,16 @@ def __init__( When provided (and *context* is ``None``), the heavy ``frontend_core`` import and context construction happen on the background thread so the shell prompt appears immediately. + + Raises: + ValueError: If both *context* and *context_kwargs* are provided. """ super().__init__() self._no_animation = no_animation + + if context is not None and context_kwargs is not None: + raise ValueError("Cannot specify both context and context_kwargs") + self._context_kwargs = context_kwargs # Track scenario execution history: list of (command_string, ScenarioResult) tuples @@ -107,7 +114,8 @@ def __init__( self._init_thread.start() def _background_init(self) -> None: - """Import heavy modules and initialize PyRIT in the background. + """ + Import heavy modules and initialize PyRIT in the background. When *context_kwargs* were provided, this thread performs the expensive ``from pyrit.cli import frontend_core`` import and creates the diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index ebe13bc18..6ef60646a 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -50,6 +50,31 @@ def test_background_init_failure_sets_event_and_raises_in_ensure_initialized(sel with pytest.raises(RuntimeError, match="Initialization failed"): shell._ensure_initialized() + def test_init_rejects_both_context_and_context_kwargs(self): + """Test that passing both context and context_kwargs raises ValueError.""" + mock_context = MagicMock() + with pytest.raises(ValueError, match="Cannot specify both"): + pyrit_shell.PyRITShell(context=mock_context, context_kwargs={"database": "SQLite"}) + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_init_with_context_kwargs_creates_context_in_background(self, mock_fc_class: MagicMock): + """Test that context_kwargs defers FrontendCore creation to the background thread.""" + mock_fc = MagicMock() + mock_fc._database = "InMemory" + mock_fc._log_level = "WARNING" + mock_fc._env_files = None + mock_fc.initialize_async = AsyncMock() + mock_fc_class.return_value = mock_fc + + shell = pyrit_shell.PyRITShell(context_kwargs={"database": "InMemory", "log_level": "WARNING"}) + shell._init_thread.join(timeout=5) + + assert shell._init_complete.is_set() + assert shell.context is mock_fc + assert shell.default_database == "InMemory" + mock_fc_class.assert_called_once_with(database="InMemory", log_level="WARNING") + mock_fc.initialize_async.assert_called_once() + def test_cmdloop_does_not_hang_when_background_init_fails(self): """Test cmdloop surfaces background initialization failures instead of waiting forever.""" mock_context = MagicMock() From fa2ee1ff73f11d278ede21d176b1198912ddbf43 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 13:02:34 -0700 Subject: [PATCH 07/14] fix: widen default_env_files type to Sequence[Path] for mypy The _resolve_env_files() method returns Optional[Sequence[Path]], not Optional[list[Path]]. Widen the type annotation and move the import into TYPE_CHECKING to satisfy both mypy and ruff TC003. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/pyrit_shell.py | 73 ++--- tests/unit/cli/test_pyrit_shell.py | 414 +++++++++++++---------------- 2 files changed, 214 insertions(+), 273 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 0e6100ed2..2d4f7ee70 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -19,6 +19,8 @@ from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: + from collections.abc import Sequence + from pyrit.cli import frontend_core from pyrit.models.scenario_result import ScenarioResult @@ -62,50 +64,32 @@ class PyRITShell(cmd.Cmd): def __init__( self, *, - context: Optional[frontend_core.FrontendCore] = None, no_animation: bool = False, - context_kwargs: Optional[dict[str, Any]] = None, + **context_kwargs: Any, ) -> None: """ Initialize the PyRIT shell. - Accepts either a pre-created context (for testing or when imports are already done) - or context_kwargs for deferred creation in the background thread. + The heavy ``frontend_core`` import, ``FrontendCore`` construction, and + ``initialize_async`` call all happen on a background thread so the + shell prompt appears immediately. Args: - context: Pre-created PyRIT context. If provided, only ``initialize_async`` - runs in the background. no_animation: If True, skip the animated startup banner. - context_kwargs: Keyword arguments forwarded to ``FrontendCore()``. - When provided (and *context* is ``None``), the heavy - ``frontend_core`` import and context construction happen on - the background thread so the shell prompt appears immediately. - - Raises: - ValueError: If both *context* and *context_kwargs* are provided. + **context_kwargs: Keyword arguments forwarded to ``FrontendCore()``. """ super().__init__() self._no_animation = no_animation - - if context is not None and context_kwargs is not None: - raise ValueError("Cannot specify both context and context_kwargs") - self._context_kwargs = context_kwargs # Track scenario execution history: list of (command_string, ScenarioResult) tuples self._scenario_history: list[tuple[str, ScenarioResult]] = [] - if context is not None: - self.context = context - self.default_database = context._database - self.default_log_level: Optional[int] = context._log_level - self.default_env_files = context._env_files - else: - # Will be set by the background thread after importing frontend_core. - self.context = None # type: ignore[assignment] - self.default_database = None - self.default_log_level = None - self.default_env_files = None + # Set by the background thread after importing frontend_core. + self.context: Optional[frontend_core.FrontendCore] = None + self.default_database: Optional[str] = None + self.default_log_level: Optional[int] = None + self.default_env_files: Optional[Sequence[Path]] = None # Initialize PyRIT in background thread for faster startup. self._init_thread = threading.Thread(target=self._background_init, daemon=True) @@ -114,21 +98,14 @@ def __init__( self._init_thread.start() def _background_init(self) -> None: - """ - Import heavy modules and initialize PyRIT in the background. - - When *context_kwargs* were provided, this thread performs the expensive - ``from pyrit.cli import frontend_core`` import and creates the - ``FrontendCore`` context before calling ``initialize_async``. - """ + """Import heavy modules and initialize PyRIT in the background.""" try: - if self.context is None: - from pyrit.cli import frontend_core as fc + from pyrit.cli import frontend_core as fc - self.context = fc.FrontendCore(**(self._context_kwargs or {})) - self.default_database = self.context._database - self.default_log_level = self.context._log_level - self.default_env_files = self.context._env_files + self.context = fc.FrontendCore(**self._context_kwargs) + self.default_database = self.context._database + self.default_log_level = self.context._log_level + self.default_env_files = self.context._env_files asyncio.run(self.context.initialize_async()) except BaseException as exc: self._init_error = exc @@ -594,14 +571,12 @@ def main() -> int: try: shell = PyRITShell( no_animation=args.no_animation, - context_kwargs={ - "config_file": args.config_file, - "database": args.database, - "initialization_scripts": None, - "initializer_names": None, - "env_files": env_files, - "log_level": args.log_level, - }, + config_file=args.config_file, + database=args.database, + initialization_scripts=None, + initializer_names=None, + env_files=env_files, + log_level=args.log_level, ) shell.cmdloop(intro=intro) return 0 diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 6ef60646a..a2ed17696 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -16,73 +16,57 @@ from pyrit.cli import pyrit_shell +@pytest.fixture() +def mock_fc(): + """Patch FrontendCore so the background thread uses a controllable mock context.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._env_files = None + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + mock_context.initialize_async = AsyncMock() + + with patch("pyrit.cli.frontend_core.FrontendCore", return_value=mock_context) as mock_fc_class: + yield mock_context, mock_fc_class + + class TestPyRITShell: """Tests for PyRITShell class.""" - def test_init(self): + def test_init(self, mock_fc): """Test PyRITShell initialization.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context.initialize_async = AsyncMock() + ctx, mock_fc_class = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) - assert shell.context == mock_context + assert shell._init_complete.is_set() + assert shell.context is ctx assert shell.default_database == "SQLite" assert shell.default_log_level == "WARNING" assert shell._scenario_history == [] - # Initialize is called in a background thread, so we need to wait for it - shell._init_complete.wait(timeout=2) - mock_context.initialize_async.assert_called_once() + mock_fc_class.assert_called_once_with() + ctx.initialize_async.assert_called_once() - def test_background_init_failure_sets_event_and_raises_in_ensure_initialized(self): + def test_background_init_failure_sets_event_and_raises_in_ensure_initialized(self, mock_fc): """Test failed background initialization unblocks waiters and surfaces the original error.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) + ctx, _ = mock_fc + ctx.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() shell._init_thread.join(timeout=2) assert shell._init_complete.is_set() with pytest.raises(RuntimeError, match="Initialization failed"): shell._ensure_initialized() - def test_init_rejects_both_context_and_context_kwargs(self): - """Test that passing both context and context_kwargs raises ValueError.""" - mock_context = MagicMock() - with pytest.raises(ValueError, match="Cannot specify both"): - pyrit_shell.PyRITShell(context=mock_context, context_kwargs={"database": "SQLite"}) - - @patch("pyrit.cli.frontend_core.FrontendCore") - def test_init_with_context_kwargs_creates_context_in_background(self, mock_fc_class: MagicMock): - """Test that context_kwargs defers FrontendCore creation to the background thread.""" - mock_fc = MagicMock() - mock_fc._database = "InMemory" - mock_fc._log_level = "WARNING" - mock_fc._env_files = None - mock_fc.initialize_async = AsyncMock() - mock_fc_class.return_value = mock_fc - - shell = pyrit_shell.PyRITShell(context_kwargs={"database": "InMemory", "log_level": "WARNING"}) - shell._init_thread.join(timeout=5) - - assert shell._init_complete.is_set() - assert shell.context is mock_fc - assert shell.default_database == "InMemory" - mock_fc_class.assert_called_once_with(database="InMemory", log_level="WARNING") - mock_fc.initialize_async.assert_called_once() - - def test_cmdloop_does_not_hang_when_background_init_fails(self): + def test_cmdloop_does_not_hang_when_background_init_fails(self, mock_fc): """Test cmdloop surfaces background initialization failures instead of waiting forever.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) + ctx, _ = mock_fc + ctx.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() shell._init_thread.join(timeout=2) errors: list[BaseException] = [] @@ -109,8 +93,9 @@ def run_cmdloop() -> None: mock_play.assert_called_once() mock_cmdloop.assert_not_called() - def test_cmdloop_does_not_block_on_slow_init(self): + def test_cmdloop_does_not_block_on_slow_init(self, mock_fc): """Test that cmdloop plays the animation immediately without waiting for initialization.""" + ctx, _ = mock_fc init_started = threading.Event() init_release = threading.Event() @@ -119,12 +104,9 @@ async def slow_init() -> None: # Block until the test explicitly releases us init_release.wait() - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context.initialize_async = slow_init + ctx.initialize_async = slow_init - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() # Wait for background init to actually start running init_started.wait(timeout=2) @@ -146,12 +128,12 @@ async def slow_init() -> None: init_release.set() shell._init_thread.join(timeout=2) - def test_prompt_and_intro(self): + def test_prompt_and_intro(self, mock_fc): """Test shell prompt is set and cmdloop wires play_animation to intro.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) assert shell.prompt == "pyrit> " @@ -165,12 +147,12 @@ def test_prompt_and_intro(self): mock_play.assert_called_once_with(no_animation=shell._no_animation) mock_cmdloop.assert_called_once_with(intro="TEST_BANNER") - def test_cmdloop_honors_explicit_intro(self): + def test_cmdloop_honors_explicit_intro(self, mock_fc): """Test that cmdloop passes through a non-None intro without calling play_animation.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) with patch("pyrit.cli._banner.play_animation") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop: shell.cmdloop(intro="Custom intro") @@ -179,24 +161,24 @@ def test_cmdloop_honors_explicit_intro(self): mock_cmdloop.assert_called_once_with(intro="Custom intro") @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - def test_do_list_scenarios(self, mock_print_scenarios: AsyncMock): + def test_do_list_scenarios(self, mock_print_scenarios: AsyncMock, mock_fc): """Test do_list_scenarios command.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_list_scenarios("") - mock_print_scenarios.assert_called_once_with(context=mock_context) + mock_print_scenarios.assert_called_once_with(context=ctx) @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, capsys): + def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, mock_fc, capsys): """Test do_list_scenarios handles exceptions.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_print_scenarios.side_effect = ValueError("Test error") - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_list_scenarios("") captured = capsys.readouterr() @@ -204,25 +186,25 @@ def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, @patch("pyrit.cli.frontend_core.get_default_initializer_discovery_path") @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers(self, mock_print_initializers: AsyncMock, mock_get_path: MagicMock): + def test_do_list_initializers(self, mock_print_initializers: AsyncMock, mock_get_path: MagicMock, mock_fc): """Test do_list_initializers command.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_path = Path("/test/path") mock_get_path.return_value = mock_path - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_list_initializers("") - mock_print_initializers.assert_called_once_with(context=mock_context, discovery_path=mock_path) + mock_print_initializers.assert_called_once_with(context=ctx, discovery_path=mock_path) @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers_with_path(self, mock_print_initializers: AsyncMock): + def test_do_list_initializers_with_path(self, mock_print_initializers: AsyncMock, mock_fc): """Test do_list_initializers with custom path.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_list_initializers("/custom/path") assert mock_print_initializers.call_count == 1 @@ -230,24 +212,24 @@ def test_do_list_initializers_with_path(self, mock_print_initializers: AsyncMock assert isinstance(call_kwargs["discovery_path"], Path) @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers_with_exception(self, mock_print_initializers: AsyncMock, capsys): + def test_do_list_initializers_with_exception(self, mock_print_initializers: AsyncMock, mock_fc, capsys): """Test do_list_initializers handles exceptions.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_print_initializers.side_effect = ValueError("Test error") - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_list_initializers("") captured = capsys.readouterr() assert "Error listing initializers" in captured.out - def test_do_run_empty_line(self, capsys): + def test_do_run_empty_line(self, mock_fc, capsys): """Test do_run with empty line.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_run("") captured = capsys.readouterr() @@ -261,14 +243,10 @@ def test_do_run_basic_scenario( mock_parse_args: MagicMock, _mock_run_scenario: AsyncMock, mock_asyncio_run: MagicMock, + mock_fc, ): """Test do_run with basic scenario.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -289,7 +267,8 @@ def test_do_run_basic_scenario( # First call is background init, second call is the actual test mock_asyncio_run.side_effect = [None, mock_result] - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_run("test_scenario --initializers test_init") mock_parse_args.assert_called_once() @@ -301,13 +280,13 @@ def test_do_run_basic_scenario( assert shell._scenario_history[0][1] == mock_result @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_do_run_parse_error(self, mock_parse_args: MagicMock, capsys): + def test_do_run_parse_error(self, mock_parse_args: MagicMock, mock_fc, capsys): """Test do_run with parse error.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.side_effect = ValueError("Parse error") - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_run("test_scenario --invalid") captured = capsys.readouterr() @@ -323,14 +302,10 @@ def test_do_run_with_initialization_scripts( mock_parse_args: MagicMock, mock_run_scenario: AsyncMock, mock_asyncio_run: MagicMock, + mock_fc, ): """Test do_run with initialization scripts.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -351,7 +326,8 @@ def test_do_run_with_initialization_scripts( # First call is background init, second call is the actual test mock_asyncio_run.side_effect = [None, MagicMock()] - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_run("test_scenario --initialization-scripts script.py") mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) @@ -363,11 +339,11 @@ def test_do_run_with_missing_script( self, mock_resolve_scripts: MagicMock, mock_parse_args: MagicMock, + mock_fc, capsys, ): """Test do_run with missing initialization script.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -386,7 +362,8 @@ def test_do_run_with_missing_script( mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_run("test_scenario --initialization-scripts missing.py") captured = capsys.readouterr() @@ -398,14 +375,10 @@ def test_do_run_with_database_override( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, + mock_fc, ): """Test do_run with database override.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -425,7 +398,8 @@ def test_do_run_with_database_override( # First call is background init, second call is the actual test mock_asyncio_run.side_effect = [None, MagicMock()] - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) with patch("pyrit.cli.frontend_core.FrontendCore") as mock_frontend: shell.do_run("test_scenario --initializers test_init --database InMemory") @@ -440,15 +414,11 @@ def test_do_run_with_exception( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, + mock_fc, capsys, ): """Test do_run handles exceptions during scenario run.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -468,29 +438,30 @@ def test_do_run_with_exception( # First call succeeds (background init), second call raises error (the actual test) mock_asyncio_run.side_effect = [None, ValueError("Test error")] - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_run("test_scenario --initializers test_init") captured = capsys.readouterr() assert "Error: Test error" in captured.out - def test_do_scenario_history_empty(self, capsys): + def test_do_scenario_history_empty(self, mock_fc, capsys): """Test do_scenario_history with no history.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_scenario_history("") captured = capsys.readouterr() assert "No scenario runs in history" in captured.out - def test_do_scenario_history_with_runs(self, capsys): + def test_do_scenario_history_with_runs(self, mock_fc, capsys): """Test do_scenario_history with scenario runs.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell._scenario_history = [ ("test_scenario1 --initializers init1", MagicMock()), ("test_scenario2 --initializers init2", MagicMock()), @@ -504,12 +475,12 @@ def test_do_scenario_history_with_runs(self, capsys): assert "test_scenario2" in captured.out assert "Total runs: 2" in captured.out - def test_do_print_scenario_empty(self, capsys): + def test_do_print_scenario_empty(self, mock_fc, capsys): """Test do_print_scenario with no history.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.do_print_scenario("") captured = capsys.readouterr() @@ -521,15 +492,16 @@ def test_do_print_scenario_all( self, mock_printer_class: MagicMock, mock_asyncio_run: MagicMock, + mock_fc, capsys, ): """Test do_print_scenario without argument prints all.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_printer = MagicMock() mock_printer_class.return_value = mock_printer - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell._scenario_history = [ ("test_scenario1", MagicMock()), ("test_scenario2", MagicMock()), @@ -548,15 +520,16 @@ def test_do_print_scenario_specific( self, mock_printer_class: MagicMock, mock_asyncio_run: MagicMock, + mock_fc, capsys, ): """Test do_print_scenario with specific scenario number.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_printer = MagicMock() mock_printer_class.return_value = mock_printer - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell._scenario_history = [ ("test_scenario1", MagicMock()), ("test_scenario2", MagicMock()), @@ -569,12 +542,12 @@ def test_do_print_scenario_specific( # 1 background init + 1 print call assert mock_asyncio_run.call_count == 2 - def test_do_print_scenario_invalid_number(self, capsys): + def test_do_print_scenario_invalid_number(self, mock_fc, capsys): """Test do_print_scenario with invalid scenario number.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell._scenario_history = [ ("test_scenario1", MagicMock()), ] @@ -584,12 +557,12 @@ def test_do_print_scenario_invalid_number(self, capsys): captured = capsys.readouterr() assert "must be between 1 and 1" in captured.out - def test_do_print_scenario_non_integer(self, capsys): + def test_do_print_scenario_non_integer(self, mock_fc, capsys): """Test do_print_scenario with non-integer argument.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell._scenario_history = [ ("test_scenario1", MagicMock()), ] @@ -599,12 +572,12 @@ def test_do_print_scenario_non_integer(self, capsys): captured = capsys.readouterr() assert "Invalid scenario number" in captured.out - def test_do_help_without_arg(self, capsys): + def test_do_help_without_arg(self, mock_fc, capsys): """Test do_help without argument.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) # Capture help output with patch("cmd.Cmd.do_help"): @@ -612,12 +585,12 @@ def test_do_help_without_arg(self, capsys): captured = capsys.readouterr() assert "Shell Startup Options" in captured.out - def test_do_help_with_arg(self): + def test_do_help_with_arg(self, mock_fc): """Test do_help with specific command.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) with patch("cmd.Cmd.do_help") as mock_parent_help: shell.do_help("run") @@ -625,14 +598,14 @@ def test_do_help_with_arg(self): @patch.object(cmd.Cmd, "cmdloop") @patch.object(banner, "play_animation") - def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock): + def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, mock_fc): """Test cmdloop wires banner.play_animation into intro and threads --no-animation.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_play.return_value = "animated banner" - shell = pyrit_shell.PyRITShell(context=mock_context, no_animation=True) + shell = pyrit_shell.PyRITShell(no_animation=True) + shell._init_thread.join(timeout=5) shell.cmdloop() mock_play.assert_called_once_with(no_animation=True) @@ -640,96 +613,96 @@ def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_ mock_cmdloop.assert_called_once_with(intro="animated banner") @patch.object(cmd.Cmd, "cmdloop") - def test_cmdloop_honors_explicit_intro(self, mock_cmdloop: MagicMock): + def test_cmdloop_honors_explicit_intro(self, mock_cmdloop: MagicMock, mock_fc): """Test cmdloop honors a non-None intro argument without calling play_animation.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.cmdloop(intro="custom intro") assert shell.intro == "custom intro" mock_cmdloop.assert_called_once_with(intro="custom intro") - def test_do_exit(self, capsys): + def test_do_exit(self, mock_fc, capsys): """Test do_exit command.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) result = shell.do_exit("") assert result is True captured = capsys.readouterr() assert "Goodbye" in captured.out - def test_do_quit_alias(self): + def test_do_quit_alias(self, mock_fc): """Test do_quit is alias for do_exit.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) assert shell.do_quit == shell.do_exit - def test_do_q_alias(self): + def test_do_q_alias(self, mock_fc): """Test do_q is alias for do_exit.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) assert shell.do_q == shell.do_exit - def test_do_eof_alias(self): + def test_do_eof_alias(self, mock_fc): """Test do_EOF is alias for do_exit.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) assert shell.do_EOF == shell.do_exit @patch("os.system") - def test_do_clear_windows(self, mock_system: MagicMock): + def test_do_clear_windows(self, mock_system: MagicMock, mock_fc): """Test do_clear on Windows.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) with patch("os.name", "nt"): shell.do_clear("") mock_system.assert_called_with("cls") @patch("os.system") - def test_do_clear_unix(self, mock_system: MagicMock): + def test_do_clear_unix(self, mock_system: MagicMock, mock_fc): """Test do_clear on Unix.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) with patch("os.name", "posix"): shell.do_clear("") mock_system.assert_called_with("clear") - def test_emptyline(self): + def test_emptyline(self, mock_fc): """Test emptyline doesn't repeat last command.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) result = shell.emptyline() assert result is False - def test_default_with_hyphen_to_underscore(self): + def test_default_with_hyphen_to_underscore(self, mock_fc): """Test default converts hyphens to underscores.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) # Mock a method with underscores shell.do_list_scenarios = MagicMock() @@ -738,12 +711,12 @@ def test_default_with_hyphen_to_underscore(self): shell.do_list_scenarios.assert_called_once_with("") - def test_default_unknown_command(self, capsys): + def test_default_unknown_command(self, mock_fc, capsys): """Test default with unknown command.""" - mock_context = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) shell.default("unknown_command") captured = capsys.readouterr() @@ -765,9 +738,8 @@ def test_main_default_args(self, mock_play: MagicMock, mock_shell_class: MagicMo assert result == 0 call_kwargs = mock_shell_class.call_args[1] - ctx_kw = call_kwargs["context_kwargs"] - assert ctx_kw["database"] is None - assert ctx_kw["log_level"] == "WARNING" + assert call_kwargs["database"] is None + assert call_kwargs["log_level"] == "WARNING" mock_shell.cmdloop.assert_called_once() @patch("pyrit.cli.pyrit_shell.PyRITShell") @@ -781,8 +753,8 @@ def test_main_with_database_arg(self, mock_play: MagicMock, mock_shell_class: Ma result = pyrit_shell.main() assert result == 0 - ctx_kw = mock_shell_class.call_args[1]["context_kwargs"] - assert ctx_kw["database"] == "InMemory" + call_kwargs = mock_shell_class.call_args[1] + assert call_kwargs["database"] == "InMemory" @patch("pyrit.cli.pyrit_shell.PyRITShell") @patch("pyrit.cli._banner.play_animation", return_value="") @@ -795,8 +767,8 @@ def test_main_with_log_level_arg(self, mock_play: MagicMock, mock_shell_class: M result = pyrit_shell.main() assert result == 0 - ctx_kw = mock_shell_class.call_args[1]["context_kwargs"] - assert ctx_kw["log_level"] == "DEBUG" + call_kwargs = mock_shell_class.call_args[1] + assert call_kwargs["log_level"] == "DEBUG" @patch("pyrit.cli.pyrit_shell.PyRITShell") @patch("pyrit.cli._banner.play_animation", return_value="") @@ -838,9 +810,9 @@ def test_main_creates_context_without_initializers(self, mock_play: MagicMock, m with patch("sys.argv", ["pyrit_shell"]): pyrit_shell.main() - ctx_kw = mock_shell_class.call_args[1]["context_kwargs"] - assert ctx_kw["initialization_scripts"] is None - assert ctx_kw["initializer_names"] is None + call_kwargs = mock_shell_class.call_args[1] + assert call_kwargs["initialization_scripts"] is None + assert call_kwargs["initializer_names"] is None @patch("pyrit.cli.pyrit_shell.PyRITShell") @patch("pyrit.cli._banner.play_animation", return_value="") @@ -880,14 +852,10 @@ def test_run_with_all_parameters( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, + mock_fc, ): """Test run command with all parameters.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -907,7 +875,8 @@ def test_run_with_all_parameters( # First call is background init, second call is the actual test mock_asyncio_run.side_effect = [None, MagicMock()] - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) with patch("pyrit.cli.frontend_core.FrontendCore"), patch("pyrit.cli.frontend_core.run_scenario_async"): shell.do_run("test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10") @@ -922,14 +891,10 @@ def test_run_stores_result_in_history( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, + mock_fc, ): """Test run command stores result in history.""" - mock_context = MagicMock() - mock_context._database = "SQLite" - mock_context._log_level = "WARNING" - mock_context._scenario_registry = MagicMock() - mock_context._initializer_registry = MagicMock() - mock_context.initialize_async = AsyncMock() + ctx, _ = mock_fc mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -951,7 +916,8 @@ def test_run_stores_result_in_history( # First call is background init, then two actual test calls mock_asyncio_run.side_effect = [None, mock_result1, mock_result2] - shell = pyrit_shell.PyRITShell(context=mock_context) + shell = pyrit_shell.PyRITShell() + shell._init_thread.join(timeout=5) # Run two scenarios shell.do_run("scenario1 --initializers init1") From 09a452c10c24acbdbf8f1e5ab95509436f88c7ad Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 14:08:45 -0700 Subject: [PATCH 08/14] refactor: use explicit typed params instead of **context_kwargs Replace untyped **context_kwargs with explicit keyword-only parameters mirroring FrontendCore.__init__ (config_file, database, env_files, etc.). This provides strong typing at the call site without importing the heavy frontend_core module. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/pyrit_shell.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 2d4f7ee70..34416225c 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -65,7 +65,12 @@ def __init__( self, *, no_animation: bool = False, - **context_kwargs: Any, + config_file: Optional[Path] = None, + database: Optional[str] = None, + initialization_scripts: Optional[list[Path]] = None, + initializer_names: Optional[list[Any]] = None, + env_files: Optional[list[Path]] = None, + log_level: Optional[int] = None, ) -> None: """ Initialize the PyRIT shell. @@ -75,12 +80,28 @@ def __init__( shell prompt appears immediately. Args: - no_animation: If True, skip the animated startup banner. - **context_kwargs: Keyword arguments forwarded to ``FrontendCore()``. + no_animation (bool): If True, skip the animated startup banner. + config_file (Optional[Path]): Path to a YAML configuration file. + database (Optional[str]): Database type (InMemory, SQLite, or AzureSQL). + initialization_scripts (Optional[list[Path]]): Initialization script paths. + initializer_names (Optional[list[Any]]): Initializer entries (names or dicts). + env_files (Optional[list[Path]]): Environment file paths to load in order. + log_level (Optional[int]): Logging level constant (e.g., ``logging.WARNING``). """ super().__init__() self._no_animation = no_animation - self._context_kwargs = context_kwargs + self._context_kwargs: dict[str, Any] = { + k: v + for k, v in { + "config_file": config_file, + "database": database, + "initialization_scripts": initialization_scripts, + "initializer_names": initializer_names, + "env_files": env_files, + "log_level": log_level, + }.items() + if v is not None + } # Track scenario execution history: list of (command_string, ScenarioResult) tuples self._scenario_history: list[tuple[str, ScenarioResult]] = [] From eb0335c9804d8a052a4535f77b3fc039f5f1422d Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 14:26:15 -0700 Subject: [PATCH 09/14] perf: bypass background thread in most unit tests (25s -> 10s) Add a 'shell' fixture that creates an already-initialized PyRITShell without spawning a thread or running asyncio.run. Only the 4 init/thread-specific tests still use the real background thread via mock_fc. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/cli/test_pyrit_shell.py | 398 ++++++++++++----------------- 1 file changed, 168 insertions(+), 230 deletions(-) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index a2ed17696..6b1100330 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -31,6 +31,33 @@ def mock_fc(): yield mock_context, mock_fc_class +@pytest.fixture() +def shell(): + """Create a fully-initialized PyRITShell without spawning a background thread. + + Bypasses the real ``_background_init`` and wires up a mock FrontendCore + directly, avoiding thread + asyncio.run overhead per test. + """ + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._env_files = None + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + mock_context.initialize_async = AsyncMock() + + with patch("pyrit.cli.frontend_core.FrontendCore", return_value=mock_context) as mock_fc_class: + with patch.object(pyrit_shell.PyRITShell, "_background_init"): + s = pyrit_shell.PyRITShell() + # Manually set the state that _background_init would have set + s.context = mock_context + s.default_database = mock_context._database + s.default_log_level = mock_context._log_level + s.default_env_files = mock_context._env_files + s._init_complete.set() + yield s, mock_context, mock_fc_class + + class TestPyRITShell: """Tests for PyRITShell class.""" @@ -128,109 +155,91 @@ async def slow_init() -> None: init_release.set() shell._init_thread.join(timeout=2) - def test_prompt_and_intro(self, mock_fc): + def test_prompt_and_intro(self, shell): """Test shell prompt is set and cmdloop wires play_animation to intro.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - - assert shell.prompt == "pyrit> " + assert s.prompt == "pyrit> " # Verify that cmdloop calls play_animation and passes the result as intro with ( patch("pyrit.cli._banner.play_animation", return_value="TEST_BANNER") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop, ): - shell.cmdloop() + s.cmdloop() - mock_play.assert_called_once_with(no_animation=shell._no_animation) + mock_play.assert_called_once_with(no_animation=s._no_animation) mock_cmdloop.assert_called_once_with(intro="TEST_BANNER") - def test_cmdloop_honors_explicit_intro(self, mock_fc): + def test_cmdloop_honors_explicit_intro(self, shell): """Test that cmdloop passes through a non-None intro without calling play_animation.""" - ctx, _ = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + s, ctx, _ = shell with patch("pyrit.cli._banner.play_animation") as mock_play, patch("cmd.Cmd.cmdloop") as mock_cmdloop: - shell.cmdloop(intro="Custom intro") + s.cmdloop(intro="Custom intro") mock_play.assert_not_called() mock_cmdloop.assert_called_once_with(intro="Custom intro") @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - def test_do_list_scenarios(self, mock_print_scenarios: AsyncMock, mock_fc): + def test_do_list_scenarios(self, mock_print_scenarios: AsyncMock, shell): """Test do_list_scenarios command.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_list_scenarios("") + s.do_list_scenarios("") mock_print_scenarios.assert_called_once_with(context=ctx) @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, mock_fc, capsys): + def test_do_list_scenarios_with_exception(self, mock_print_scenarios: AsyncMock, shell, capsys): """Test do_list_scenarios handles exceptions.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_print_scenarios.side_effect = ValueError("Test error") - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_list_scenarios("") + s.do_list_scenarios("") captured = capsys.readouterr() assert "Error listing scenarios" in captured.out @patch("pyrit.cli.frontend_core.get_default_initializer_discovery_path") @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers(self, mock_print_initializers: AsyncMock, mock_get_path: MagicMock, mock_fc): + def test_do_list_initializers(self, mock_print_initializers: AsyncMock, mock_get_path: MagicMock, shell): """Test do_list_initializers command.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_path = Path("/test/path") mock_get_path.return_value = mock_path - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_list_initializers("") + s.do_list_initializers("") mock_print_initializers.assert_called_once_with(context=ctx, discovery_path=mock_path) @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers_with_path(self, mock_print_initializers: AsyncMock, mock_fc): + def test_do_list_initializers_with_path(self, mock_print_initializers: AsyncMock, shell): """Test do_list_initializers with custom path.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_list_initializers("/custom/path") + s.do_list_initializers("/custom/path") assert mock_print_initializers.call_count == 1 call_kwargs = mock_print_initializers.call_args[1] assert isinstance(call_kwargs["discovery_path"], Path) @patch("pyrit.cli.frontend_core.print_initializers_list_async", new_callable=AsyncMock) - def test_do_list_initializers_with_exception(self, mock_print_initializers: AsyncMock, mock_fc, capsys): + def test_do_list_initializers_with_exception(self, mock_print_initializers: AsyncMock, shell, capsys): """Test do_list_initializers handles exceptions.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_print_initializers.side_effect = ValueError("Test error") - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_list_initializers("") + s.do_list_initializers("") captured = capsys.readouterr() assert "Error listing initializers" in captured.out - def test_do_run_empty_line(self, mock_fc, capsys): + def test_do_run_empty_line(self, shell, capsys): """Test do_run with empty line.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_run("") + s.do_run("") captured = capsys.readouterr() assert "Specify a scenario name" in captured.out @@ -243,10 +252,10 @@ def test_do_run_basic_scenario( mock_parse_args: MagicMock, _mock_run_scenario: AsyncMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, ): """Test do_run with basic scenario.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -264,30 +273,25 @@ def test_do_run_basic_scenario( } mock_result = MagicMock() - # First call is background init, second call is the actual test - mock_asyncio_run.side_effect = [None, mock_result] + mock_asyncio_run.side_effect = [mock_result] - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_run("test_scenario --initializers test_init") + s.do_run("test_scenario --initializers test_init") mock_parse_args.assert_called_once() - assert mock_asyncio_run.call_count == 2 + assert mock_asyncio_run.call_count == 1 # Verify result was stored in history - assert len(shell._scenario_history) == 1 - assert shell._scenario_history[0][0] == "test_scenario --initializers test_init" - assert shell._scenario_history[0][1] == mock_result + assert len(s._scenario_history) == 1 + assert s._scenario_history[0][0] == "test_scenario --initializers test_init" + assert s._scenario_history[0][1] == mock_result @patch("pyrit.cli.frontend_core.parse_run_arguments") - def test_do_run_parse_error(self, mock_parse_args: MagicMock, mock_fc, capsys): + def test_do_run_parse_error(self, mock_parse_args: MagicMock, shell, capsys): """Test do_run with parse error.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.side_effect = ValueError("Parse error") - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_run("test_scenario --invalid") + s.do_run("test_scenario --invalid") captured = capsys.readouterr() assert "Error: Parse error" in captured.out @@ -302,10 +306,10 @@ def test_do_run_with_initialization_scripts( mock_parse_args: MagicMock, mock_run_scenario: AsyncMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, ): """Test do_run with initialization scripts.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -323,15 +327,12 @@ def test_do_run_with_initialization_scripts( } mock_resolve_scripts.return_value = [Path("/test/script.py")] - # First call is background init, second call is the actual test - mock_asyncio_run.side_effect = [None, MagicMock()] + mock_asyncio_run.side_effect = [MagicMock()] - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_run("test_scenario --initialization-scripts script.py") + s.do_run("test_scenario --initialization-scripts script.py") mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) - assert mock_asyncio_run.call_count == 2 + assert mock_asyncio_run.call_count == 1 @patch("pyrit.cli.frontend_core.parse_run_arguments") @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") @@ -339,11 +340,11 @@ def test_do_run_with_missing_script( self, mock_resolve_scripts: MagicMock, mock_parse_args: MagicMock, - mock_fc, + shell, capsys, ): """Test do_run with missing initialization script.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -362,9 +363,7 @@ def test_do_run_with_missing_script( mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_run("test_scenario --initialization-scripts missing.py") + s.do_run("test_scenario --initialization-scripts missing.py") captured = capsys.readouterr() assert "Error: Script not found" in captured.out @@ -375,10 +374,10 @@ def test_do_run_with_database_override( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, ): """Test do_run with database override.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -395,14 +394,10 @@ def test_do_run_with_database_override( "max_dataset_size": None, } - # First call is background init, second call is the actual test - mock_asyncio_run.side_effect = [None, MagicMock()] - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + mock_asyncio_run.side_effect = [MagicMock()] with patch("pyrit.cli.frontend_core.FrontendCore") as mock_frontend: - shell.do_run("test_scenario --initializers test_init --database InMemory") + s.do_run("test_scenario --initializers test_init --database InMemory") # Verify FrontendCore was created with overridden database call_kwargs = mock_frontend.call_args[1] @@ -414,11 +409,11 @@ def test_do_run_with_exception( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, capsys, ): """Test do_run handles exceptions during scenario run.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -435,39 +430,32 @@ def test_do_run_with_exception( "max_dataset_size": None, } - # First call succeeds (background init), second call raises error (the actual test) - mock_asyncio_run.side_effect = [None, ValueError("Test error")] + mock_asyncio_run.side_effect = [ValueError("Test error")] - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_run("test_scenario --initializers test_init") + s.do_run("test_scenario --initializers test_init") captured = capsys.readouterr() assert "Error: Test error" in captured.out - def test_do_scenario_history_empty(self, mock_fc, capsys): + def test_do_scenario_history_empty(self, shell, capsys): """Test do_scenario_history with no history.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_scenario_history("") + s.do_scenario_history("") captured = capsys.readouterr() assert "No scenario runs in history" in captured.out - def test_do_scenario_history_with_runs(self, mock_fc, capsys): + def test_do_scenario_history_with_runs(self, shell, capsys): """Test do_scenario_history with scenario runs.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell._scenario_history = [ + s._scenario_history = [ ("test_scenario1 --initializers init1", MagicMock()), ("test_scenario2 --initializers init2", MagicMock()), ] - shell.do_scenario_history("") + s.do_scenario_history("") captured = capsys.readouterr() assert "Scenario Run History" in captured.out @@ -475,13 +463,11 @@ def test_do_scenario_history_with_runs(self, mock_fc, capsys): assert "test_scenario2" in captured.out assert "Total runs: 2" in captured.out - def test_do_print_scenario_empty(self, mock_fc, capsys): + def test_do_print_scenario_empty(self, shell, capsys): """Test do_print_scenario with no history.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.do_print_scenario("") + s.do_print_scenario("") captured = capsys.readouterr() assert "No scenario runs in history" in captured.out @@ -492,27 +478,25 @@ def test_do_print_scenario_all( self, mock_printer_class: MagicMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, capsys, ): """Test do_print_scenario without argument prints all.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_printer = MagicMock() mock_printer_class.return_value = mock_printer - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell._scenario_history = [ + s._scenario_history = [ ("test_scenario1", MagicMock()), ("test_scenario2", MagicMock()), ] - shell.do_print_scenario("") + s.do_print_scenario("") captured = capsys.readouterr() assert "Printing all scenario results" in captured.out - # 1 background init + 2 print calls - assert mock_asyncio_run.call_count == 3 + # 2 print calls (no background init) + assert mock_asyncio_run.call_count == 2 @patch("pyrit.cli.pyrit_shell.asyncio.run") @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") @@ -520,204 +504,166 @@ def test_do_print_scenario_specific( self, mock_printer_class: MagicMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, capsys, ): """Test do_print_scenario with specific scenario number.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_printer = MagicMock() mock_printer_class.return_value = mock_printer - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell._scenario_history = [ + s._scenario_history = [ ("test_scenario1", MagicMock()), ("test_scenario2", MagicMock()), ] - shell.do_print_scenario("1") + s.do_print_scenario("1") captured = capsys.readouterr() assert "Scenario Run #1" in captured.out - # 1 background init + 1 print call - assert mock_asyncio_run.call_count == 2 + # 1 print call (no background init) + assert mock_asyncio_run.call_count == 1 - def test_do_print_scenario_invalid_number(self, mock_fc, capsys): + def test_do_print_scenario_invalid_number(self, shell, capsys): """Test do_print_scenario with invalid scenario number.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell._scenario_history = [ + s._scenario_history = [ ("test_scenario1", MagicMock()), ] - shell.do_print_scenario("5") + s.do_print_scenario("5") captured = capsys.readouterr() assert "must be between 1 and 1" in captured.out - def test_do_print_scenario_non_integer(self, mock_fc, capsys): + def test_do_print_scenario_non_integer(self, shell, capsys): """Test do_print_scenario with non-integer argument.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell._scenario_history = [ + s._scenario_history = [ ("test_scenario1", MagicMock()), ] - shell.do_print_scenario("invalid") + s.do_print_scenario("invalid") captured = capsys.readouterr() assert "Invalid scenario number" in captured.out - def test_do_help_without_arg(self, mock_fc, capsys): + def test_do_help_without_arg(self, shell, capsys): """Test do_help without argument.""" - ctx, _ = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + s, ctx, _ = shell # Capture help output with patch("cmd.Cmd.do_help"): - shell.do_help("") + s.do_help("") captured = capsys.readouterr() assert "Shell Startup Options" in captured.out - def test_do_help_with_arg(self, mock_fc): + def test_do_help_with_arg(self, shell): """Test do_help with specific command.""" - ctx, _ = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + s, ctx, _ = shell with patch("cmd.Cmd.do_help") as mock_parent_help: - shell.do_help("run") + s.do_help("run") mock_parent_help.assert_called_with("run") @patch.object(cmd.Cmd, "cmdloop") @patch.object(banner, "play_animation") - def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, mock_fc): + def test_cmdloop_sets_intro_via_play_animation(self, mock_play: MagicMock, mock_cmdloop: MagicMock, shell): """Test cmdloop wires banner.play_animation into intro and threads --no-animation.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_play.return_value = "animated banner" - shell = pyrit_shell.PyRITShell(no_animation=True) - shell._init_thread.join(timeout=5) - shell.cmdloop() + # Note: no_animation is not set because shell fixture uses default + s._no_animation = True + s.cmdloop() mock_play.assert_called_once_with(no_animation=True) - assert shell.intro == "animated banner" + assert s.intro == "animated banner" mock_cmdloop.assert_called_once_with(intro="animated banner") @patch.object(cmd.Cmd, "cmdloop") - def test_cmdloop_honors_explicit_intro(self, mock_cmdloop: MagicMock, mock_fc): + def test_cmdloop_honors_explicit_intro(self, mock_cmdloop: MagicMock, shell): """Test cmdloop honors a non-None intro argument without calling play_animation.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.cmdloop(intro="custom intro") + s.cmdloop(intro="custom intro") - assert shell.intro == "custom intro" + assert s.intro == "custom intro" mock_cmdloop.assert_called_once_with(intro="custom intro") - def test_do_exit(self, mock_fc, capsys): + def test_do_exit(self, shell, capsys): """Test do_exit command.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - result = shell.do_exit("") + result = s.do_exit("") assert result is True captured = capsys.readouterr() assert "Goodbye" in captured.out - def test_do_quit_alias(self, mock_fc): + def test_do_quit_alias(self, shell): """Test do_quit is alias for do_exit.""" - ctx, _ = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + s, ctx, _ = shell - assert shell.do_quit == shell.do_exit + assert s.do_quit == s.do_exit - def test_do_q_alias(self, mock_fc): + def test_do_q_alias(self, shell): """Test do_q is alias for do_exit.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + assert s.do_q == s.do_exit - assert shell.do_q == shell.do_exit - - def test_do_eof_alias(self, mock_fc): + def test_do_eof_alias(self, shell): """Test do_EOF is alias for do_exit.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - - assert shell.do_EOF == shell.do_exit + assert s.do_EOF == s.do_exit @patch("os.system") - def test_do_clear_windows(self, mock_system: MagicMock, mock_fc): + def test_do_clear_windows(self, mock_system: MagicMock, shell): """Test do_clear on Windows.""" - ctx, _ = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + s, ctx, _ = shell with patch("os.name", "nt"): - shell.do_clear("") + s.do_clear("") mock_system.assert_called_with("cls") @patch("os.system") - def test_do_clear_unix(self, mock_system: MagicMock, mock_fc): + def test_do_clear_unix(self, mock_system: MagicMock, shell): """Test do_clear on Unix.""" - ctx, _ = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + s, ctx, _ = shell with patch("os.name", "posix"): - shell.do_clear("") + s.do_clear("") mock_system.assert_called_with("clear") - def test_emptyline(self, mock_fc): + def test_emptyline(self, shell): """Test emptyline doesn't repeat last command.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - result = shell.emptyline() + result = s.emptyline() assert result is False - def test_default_with_hyphen_to_underscore(self, mock_fc): + def test_default_with_hyphen_to_underscore(self, shell): """Test default converts hyphens to underscores.""" - ctx, _ = mock_fc - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + s, ctx, _ = shell # Mock a method with underscores - shell.do_list_scenarios = MagicMock() + s.do_list_scenarios = MagicMock() - shell.default("list-scenarios") + s.default("list-scenarios") - shell.do_list_scenarios.assert_called_once_with("") + s.do_list_scenarios.assert_called_once_with("") - def test_default_unknown_command(self, mock_fc, capsys): + def test_default_unknown_command(self, shell, capsys): """Test default with unknown command.""" - ctx, _ = mock_fc + s, ctx, _ = shell - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) - shell.default("unknown_command") + s.default("unknown_command") captured = capsys.readouterr() assert "Unknown command" in captured.out @@ -852,10 +798,10 @@ def test_run_with_all_parameters( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, ): """Test run command with all parameters.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -872,18 +818,14 @@ def test_run_with_all_parameters( "max_dataset_size": None, } - # First call is background init, second call is the actual test - mock_asyncio_run.side_effect = [None, MagicMock()] - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + mock_asyncio_run.side_effect = [MagicMock()] with patch("pyrit.cli.frontend_core.FrontendCore"), patch("pyrit.cli.frontend_core.run_scenario_async"): - shell.do_run("test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10") + s.do_run("test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10") # Verify run_scenario_async was called with correct args # (it's called via asyncio.run, so check the mock_asyncio_run call) - assert mock_asyncio_run.call_count == 2 + assert mock_asyncio_run.call_count == 1 @patch("pyrit.cli.pyrit_shell.asyncio.run") @patch("pyrit.cli.frontend_core.parse_run_arguments") @@ -891,10 +833,10 @@ def test_run_stores_result_in_history( self, mock_parse_args: MagicMock, mock_asyncio_run: MagicMock, - mock_fc, + shell, ): """Test run command stores result in history.""" - ctx, _ = mock_fc + s, ctx, _ = shell mock_parse_args.return_value = { "scenario_name": "test_scenario", @@ -913,17 +855,13 @@ def test_run_stores_result_in_history( mock_result1 = MagicMock() mock_result2 = MagicMock() - # First call is background init, then two actual test calls - mock_asyncio_run.side_effect = [None, mock_result1, mock_result2] - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=5) + mock_asyncio_run.side_effect = [mock_result1, mock_result2] # Run two scenarios - shell.do_run("scenario1 --initializers init1") - shell.do_run("scenario2 --initializers init2") + s.do_run("scenario1 --initializers init1") + s.do_run("scenario2 --initializers init2") # Verify both are in history - assert len(shell._scenario_history) == 2 - assert shell._scenario_history[0][1] == mock_result1 - assert shell._scenario_history[1][1] == mock_result2 + assert len(s._scenario_history) == 2 + assert s._scenario_history[0][1] == mock_result1 + assert s._scenario_history[1][1] == mock_result2 From 0ac0871e26bd3473cf320149beaab24f728f1821 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 14:34:14 -0700 Subject: [PATCH 10/14] feat: add deprecated context param with validation Add back context as an optional deprecated parameter to PyRITShell. Emits a DeprecationWarning via print_deprecation_message (removed in 0.14.0). Raises ValueError if context is provided together with any FrontendCore keyword arguments. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/pyrit_shell.py | 31 ++++++++++++++++++++++++++++-- tests/unit/cli/test_pyrit_shell.py | 17 ++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 34416225c..d674d5617 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -71,6 +71,7 @@ def __init__( initializer_names: Optional[list[Any]] = None, env_files: Optional[list[Path]] = None, log_level: Optional[int] = None, + context: Optional[frontend_core.FrontendCore] = None, ) -> None: """ Initialize the PyRIT shell. @@ -87,6 +88,12 @@ def __init__( initializer_names (Optional[list[Any]]): Initializer entries (names or dicts). env_files (Optional[list[Path]]): Environment file paths to load in order. log_level (Optional[int]): Logging level constant (e.g., ``logging.WARNING``). + context (Optional[frontend_core.FrontendCore]): Deprecated. Pre-created FrontendCore + context. Use the individual keyword arguments instead. + + Raises: + ValueError: If ``context`` is provided together with any other + FrontendCore keyword arguments. """ super().__init__() self._no_animation = no_animation @@ -103,6 +110,23 @@ def __init__( if v is not None } + if context is not None: + if self._context_kwargs: + raise ValueError( + "Cannot pass 'context' together with FrontendCore keyword arguments " + f"({', '.join(self._context_kwargs)}). Use one or the other." + ) + from pyrit.common.deprecation import print_deprecation_message + + print_deprecation_message( + old_item="PyRITShell(context=...)", + new_item="PyRITShell(database=..., log_level=..., ...)", + removed_in="0.14.0", + ) + self._deprecated_context = context + else: + self._deprecated_context = None + # Track scenario execution history: list of (command_string, ScenarioResult) tuples self._scenario_history: list[tuple[str, ScenarioResult]] = [] @@ -121,9 +145,12 @@ def __init__( def _background_init(self) -> None: """Import heavy modules and initialize PyRIT in the background.""" try: - from pyrit.cli import frontend_core as fc + if self._deprecated_context is not None: + self.context = self._deprecated_context + else: + from pyrit.cli import frontend_core as fc - self.context = fc.FrontendCore(**self._context_kwargs) + self.context = fc.FrontendCore(**self._context_kwargs) self.default_database = self.context._database self.default_log_level = self.context._log_level self.default_env_files = self.context._env_files diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 6b1100330..fcd4f9f31 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -155,6 +155,23 @@ async def slow_init() -> None: init_release.set() shell._init_thread.join(timeout=2) + def test_deprecated_context_param_emits_warning(self, mock_fc): + """Test that passing context= emits a DeprecationWarning and uses the provided context.""" + ctx, _ = mock_fc + + with pytest.warns(DeprecationWarning, match="context"): + shell = pyrit_shell.PyRITShell(context=ctx) + shell._init_thread.join(timeout=5) + + assert shell.context is ctx + + def test_context_with_kwargs_raises_value_error(self, mock_fc): + """Test that passing both context and FrontendCore kwargs raises ValueError.""" + ctx, _ = mock_fc + + with pytest.raises(ValueError, match="Cannot pass 'context' together with"): + pyrit_shell.PyRITShell(context=ctx, database="InMemory") + def test_prompt_and_intro(self, shell): """Test shell prompt is set and cmdloop wires play_animation to intro.""" s, ctx, _ = shell From d90f5185cff6825b53862b2967302d99ab93ec83 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 14:41:49 -0700 Subject: [PATCH 11/14] test: remove timing-sensitive tests from unit suite Remove test_cmdloop_does_not_hang_when_background_init_fails and test_cmdloop_does_not_block_on_slow_init from unit tests. These thread-timing tests are flaky on CI; startup performance is already covered by the integration test. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/unit/cli/test_pyrit_shell.py | 68 ------------------------------ 1 file changed, 68 deletions(-) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index fcd4f9f31..1773278a4 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -6,7 +6,6 @@ """ import cmd -import threading from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -88,73 +87,6 @@ def test_background_init_failure_sets_event_and_raises_in_ensure_initialized(sel with pytest.raises(RuntimeError, match="Initialization failed"): shell._ensure_initialized() - def test_cmdloop_does_not_hang_when_background_init_fails(self, mock_fc): - """Test cmdloop surfaces background initialization failures instead of waiting forever.""" - ctx, _ = mock_fc - ctx.initialize_async = AsyncMock(side_effect=RuntimeError("Initialization failed")) - - shell = pyrit_shell.PyRITShell() - shell._init_thread.join(timeout=2) - - errors: list[BaseException] = [] - - def run_cmdloop() -> None: - try: - shell.cmdloop() - except BaseException as exc: # pragma: no cover - assertion target - errors.append(exc) - - with ( - patch("pyrit.cli._banner.play_animation") as mock_play, - patch("cmd.Cmd.cmdloop") as mock_cmdloop, - ): - cmdloop_thread = threading.Thread(target=run_cmdloop, daemon=True) - cmdloop_thread.start() - cmdloop_thread.join(timeout=5) - - assert not cmdloop_thread.is_alive() - assert len(errors) == 1 - assert isinstance(errors[0], RuntimeError) - assert str(errors[0]) == "Initialization failed" - # Animation plays first (no blocking wait), then error is surfaced - mock_play.assert_called_once() - mock_cmdloop.assert_not_called() - - def test_cmdloop_does_not_block_on_slow_init(self, mock_fc): - """Test that cmdloop plays the animation immediately without waiting for initialization.""" - ctx, _ = mock_fc - init_started = threading.Event() - init_release = threading.Event() - - async def slow_init() -> None: - init_started.set() - # Block until the test explicitly releases us - init_release.wait() - - ctx.initialize_async = slow_init - - shell = pyrit_shell.PyRITShell() - # Wait for background init to actually start running - init_started.wait(timeout=2) - - with ( - patch("pyrit.cli._banner.play_animation", return_value="BANNER") as mock_play, - patch("cmd.Cmd.cmdloop") as mock_cmdloop, - ): - cmdloop_thread = threading.Thread(target=shell.cmdloop, daemon=True) - cmdloop_thread.start() - cmdloop_thread.join(timeout=2) - - # cmdloop should have completed even though init is still running - assert not cmdloop_thread.is_alive() - mock_play.assert_called_once() - mock_cmdloop.assert_called_once() - assert not shell._init_complete.is_set() - - # Clean up: let the background init finish - init_release.set() - shell._init_thread.join(timeout=2) - def test_deprecated_context_param_emits_warning(self, mock_fc): """Test that passing context= emits a DeprecationWarning and uses the provided context.""" ctx, _ = mock_fc From 267bd849c671b2870eaa289df587b00641e26977 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 17 Mar 2026 15:43:40 -0700 Subject: [PATCH 12/14] refactor: extract shared CLI args into lightweight _cli_args module Move constants, ARG_HELP, validators, argparse wrappers, and parsers from frontend_core.py into _cli_args.py (stdlib-only imports). Both pyrit_shell and pyrit_scan can now reference help text and constants without the heavy frontend_core import. frontend_core re-exports everything for backward compatibility. Also replaces per-method lazy frontend_core imports in PyRITShell with a single self._fc reference set during background init. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/_cli_args.py | 495 +++++++++++++++++++++++++++++ pyrit/cli/frontend_core.py | 431 ++----------------------- pyrit/cli/pyrit_shell.py | 73 ++--- tests/unit/cli/test_pyrit_shell.py | 3 + 4 files changed, 548 insertions(+), 454 deletions(-) create mode 100644 pyrit/cli/_cli_args.py diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py new file mode 100644 index 000000000..29d1810db --- /dev/null +++ b/pyrit/cli/_cli_args.py @@ -0,0 +1,495 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Lightweight shared CLI argument definitions for PyRIT frontends. + +This module contains constants, validators, help text, and argument parsers +that are shared between ``pyrit_shell``, ``pyrit_scan``, and other CLI entry +points. It intentionally avoids heavy imports (no ``pyrit.scenario``, +``pyrit.registry``, ``pyrit.setup``, etc.) so it can be loaded quickly for +argument parsing before the full runtime is initialised. +""" + +from __future__ import annotations + +import argparse +import inspect +import json +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from collections.abc import Callable + +# --------------------------------------------------------------------------- +# Database type constants +# --------------------------------------------------------------------------- +IN_MEMORY = "InMemory" +SQLITE = "SQLite" +AZURE_SQL = "AzureSQL" + + +# --------------------------------------------------------------------------- +# Pure validators +# --------------------------------------------------------------------------- + + +def validate_database(*, database: str) -> str: + """ + Validate database type. + + Args: + database: Database type string. + + Returns: + Validated database type. + + Raises: + ValueError: If database type is invalid. + """ + valid_databases = [IN_MEMORY, SQLITE, AZURE_SQL] + if database not in valid_databases: + raise ValueError(f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") + return database + + +def validate_log_level(*, log_level: str) -> int: + """ + Validate log level and convert to logging constant. + + Args: + log_level: Log level string (case-insensitive). + + Returns: + Validated log level as logging constant (e.g., logging.WARNING). + + Raises: + ValueError: If log level is invalid. + """ + valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + level_upper = log_level.upper() + if level_upper not in valid_levels: + raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") + level_value: int = getattr(logging, level_upper) + return level_value + + +def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: + """ + Validate and parse an integer value. + + Note: The 'value' parameter is positional (not keyword-only) to allow use with + argparse lambdas like: lambda v: validate_integer(v, min_value=1). + This is an exception to the PyRIT style guide for argparse compatibility. + + Args: + value: String value to parse. + name: Parameter name for error messages. Defaults to "value". + min_value: Optional minimum value constraint. + + Returns: + Parsed integer. + + Raises: + ValueError: If value is not a valid integer or violates constraints. + """ + # Reject boolean types explicitly (int(True) == 1, int(False) == 0) + if isinstance(value, bool): + raise ValueError(f"{name} must be an integer string, got boolean: {value}") + + # Ensure value is a string + if not isinstance(value, str): + raise ValueError(f"{name} must be a string, got {type(value).__name__}: {value}") + + # Strip whitespace and validate it looks like an integer + value = value.strip() + if not value: + raise ValueError(f"{name} cannot be empty") + + try: + int_value = int(value) + except (ValueError, TypeError) as e: + raise ValueError(f"{name} must be an integer, got: {value}") from e + + if min_value is not None and int_value < min_value: + raise ValueError(f"{name} must be at least {min_value}, got: {int_value}") + + return int_value + + +# --------------------------------------------------------------------------- +# Argparse adapter +# --------------------------------------------------------------------------- + + +def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], Any]: + """ + Adapt a validator to argparse by converting ValueError to ArgumentTypeError. + + This decorator adapts our keyword-only validators for use with argparse's type= parameter. + It handles two challenges: + + 1. Exception Translation: argparse expects ArgumentTypeError, but our validators raise + ValueError. This decorator catches ValueError and re-raises as ArgumentTypeError. + + 2. Keyword-Only Parameters: PyRIT validators use keyword-only parameters (e.g., + validate_database(*, database: str)), but argparse's type= passes a positional argument. + This decorator inspects the function signature and calls the validator with the correct + keyword argument name. + + This pattern allows us to: + - Keep validators as pure functions with proper type hints + - Follow PyRIT style guide (keyword-only parameters) + - Reuse the same validation logic in both argparse and non-argparse contexts + + Args: + validator_func: Function that raises ValueError on invalid input. + Must have at least one parameter (can be keyword-only). + + Returns: + Wrapped function that: + - Accepts a single positional argument (for argparse compatibility) + - Calls validator_func with the correct keyword argument + - Raises ArgumentTypeError instead of ValueError + + Raises: + ValueError: If validator_func has no parameters. + """ + # Get the first parameter name from the function signature + sig = inspect.signature(validator_func) + params = list(sig.parameters.keys()) + if not params: + raise ValueError(f"Validator function {validator_func.__name__} must have at least one parameter") + first_param = params[0] + + def wrapper(value: Any) -> Any: + try: + # Call with keyword argument to support keyword-only parameters + return validator_func(**{first_param: value}) + except ValueError as e: + raise argparse.ArgumentTypeError(str(e)) from e + + # Preserve function metadata for better debugging + wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") + wrapper.__doc__ = getattr(validator_func, "__doc__", None) + return wrapper + + +# --------------------------------------------------------------------------- +# Path / env-file helpers +# --------------------------------------------------------------------------- + + +def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: + """ + Resolve environment file paths to absolute Path objects. + + Args: + env_file_paths: List of environment file path strings. + + Returns: + List of resolved Path objects. + + Raises: + ValueError: If any path does not exist. + """ + resolved_paths = [] + for path_str in env_file_paths: + path = Path(path_str).resolve() + if not path.exists(): + raise ValueError(f"Environment file not found: {path}") + resolved_paths.append(path) + return resolved_paths + + +# --------------------------------------------------------------------------- +# Argparse-compatible validators +# +# These wrappers adapt our core validators (which use keyword-only parameters and raise +# ValueError) for use with argparse's type= parameter (which passes positional arguments +# and expects ArgumentTypeError). +# +# Pattern: +# - Use core validators (validate_database, validate_log_level, etc.) in regular code +# - Use these _argparse versions ONLY in parser.add_argument(..., type=...) +# +# The lambda wrappers for validate_integer are necessary because we need to partially +# apply the min_value parameter while still allowing the decorator to work correctly. +# --------------------------------------------------------------------------- +validate_database_argparse = _argparse_validator(validate_database) +validate_log_level_argparse = _argparse_validator(validate_log_level) +positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) +non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) +resolve_env_files_argparse = _argparse_validator(resolve_env_files) + + +# --------------------------------------------------------------------------- +# Memory label / argument parsing +# --------------------------------------------------------------------------- + + +def parse_memory_labels(json_string: str) -> dict[str, str]: + """ + Parse memory labels from a JSON string. + + Args: + json_string: JSON string containing label key-value pairs. + + Returns: + Dictionary of labels. + + Raises: + ValueError: If JSON is invalid or contains non-string values. + """ + try: + labels = json.loads(json_string) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON for memory labels: {e}") from e + + if not isinstance(labels, dict): + raise ValueError("Memory labels must be a JSON object (dictionary)") + + # Validate all keys and values are strings + for key, value in labels.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise ValueError(f"All label keys and values must be strings. Got: {key}={value}") + + return labels + + +# --------------------------------------------------------------------------- +# Shared argument help text +# --------------------------------------------------------------------------- +ARG_HELP = { + "config_file": ( + "Path to a YAML configuration file. Allows specifying database, initializers (with args), " + "initialization scripts, and env files. CLI arguments override config file values. " + "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." + ), + "initializers": ( + "Built-in initializer names to run before the scenario. " + "Supports optional params with name:key=val syntax " + "(e.g., target:tags=default,scorer dataset:mode=strict)" + ), + "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", + "env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files " + "override earlier ones.", + "scenario_strategies": "List of strategy names to run (e.g., base64 rot13)", + "max_concurrency": "Maximum number of concurrent attack executions (must be >= 1)", + "max_retries": "Maximum number of automatic retries on exception (must be >= 0)", + "memory_labels": 'Additional labels as JSON string (e.g., \'{"experiment": "test1"}\')', + "database": "Database type to use for memory storage", + "log_level": "Logging level", + "dataset_names": "List of dataset names to use instead of scenario defaults (e.g., harmbench advbench). " + "Creates a new dataset config; fetches all items unless --max-dataset-size is also specified", + "max_dataset_size": "Maximum number of items to use from the dataset (must be >= 1). " + "Limits new datasets if --dataset-names provided, otherwise overrides scenario's default limit", +} + + +# --------------------------------------------------------------------------- +# Initializer argument parsing +# --------------------------------------------------------------------------- + + +def _parse_initializer_arg(arg: str) -> str | dict[str, Any]: + """ + Parse an initializer CLI argument into a string or dict for ConfigurationLoader. + + Supports two formats: + - Simple name: "simple" → "simple" + - Name with params: "target:tags=default,scorer" → {"name": "target", "args": {"tags": ["default", "scorer"]}} + + For multiple params on one initializer, separate with semicolons: "name:key1=val1;key2=val2" + For multiple initializers with params, space-separate them: "target:tags=a,b dataset:mode=strict" + + Args: + arg: The CLI argument string. + + Returns: + str | dict[str, Any]: A plain name string, or a dict with 'name' and 'args' keys. + + Raises: + ValueError: If the argument format is invalid. + """ + if ":" not in arg: + return arg + + name, params_str = arg.split(":", 1) + if not name: + raise ValueError(f"Invalid initializer argument '{arg}': missing name before ':'") + + args: dict[str, list[str]] = {} + for pair in params_str.split(";"): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + raise ValueError(f"Invalid initializer parameter '{pair}' in '{arg}': expected key=value format") + key, value = pair.split("=", 1) + key = key.strip() + if not key: + raise ValueError(f"Invalid initializer parameter in '{arg}': empty key") + args[key] = [v.strip() for v in value.split(",")] + + if args: + return {"name": name, "args": args} + return name + + +def parse_run_arguments(*, args_string: str) -> dict[str, Any]: + """ + Parse run command arguments from a string (for shell mode). + + Args: + args_string: Space-separated argument string (e.g., "scenario_name --initializers foo --strategies bar"). + + Returns: + Dictionary with parsed arguments: + - scenario_name: str + - initializers: Optional[list[str]] + - initialization_scripts: Optional[list[str]] + - scenario_strategies: Optional[list[str]] + - max_concurrency: Optional[int] + - max_retries: Optional[int] + - memory_labels: Optional[dict[str, str]] + - database: Optional[str] + - log_level: Optional[int] + - dataset_names: Optional[list[str]] + - max_dataset_size: Optional[int] + + Raises: + ValueError: If parsing or validation fails. + """ + parts = args_string.split() + + if not parts: + raise ValueError("No scenario name provided") + + result: dict[str, Any] = { + "scenario_name": parts[0], + "initializers": None, + "initialization_scripts": None, + "env_files": None, + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": None, + "log_level": None, + "dataset_names": None, + "max_dataset_size": None, + } + + i = 1 + while i < len(parts): + if parts[i] == "--initializers": + # Collect initializers until next flag, parsing name:key=val syntax + result["initializers"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initializers"].append(_parse_initializer_arg(parts[i])) + i += 1 + elif parts[i] == "--initialization-scripts": + # Collect script paths until next flag + result["initialization_scripts"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initialization_scripts"].append(parts[i]) + i += 1 + elif parts[i] == "--env-files": + # Collect env file paths until next flag + result["env_files"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["env_files"].append(parts[i]) + i += 1 + elif parts[i] in ("--strategies", "-s"): + # Collect strategies until next flag + result["scenario_strategies"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s": + result["scenario_strategies"].append(parts[i]) + i += 1 + elif parts[i] == "--max-concurrency": + i += 1 + if i >= len(parts): + raise ValueError("--max-concurrency requires a value") + result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) + i += 1 + elif parts[i] == "--max-retries": + i += 1 + if i >= len(parts): + raise ValueError("--max-retries requires a value") + result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) + i += 1 + elif parts[i] == "--memory-labels": + i += 1 + if i >= len(parts): + raise ValueError("--memory-labels requires a value") + result["memory_labels"] = parse_memory_labels(parts[i]) + i += 1 + elif parts[i] == "--database": + i += 1 + if i >= len(parts): + raise ValueError("--database requires a value") + result["database"] = validate_database(database=parts[i]) + i += 1 + elif parts[i] == "--log-level": + i += 1 + if i >= len(parts): + raise ValueError("--log-level requires a value") + result["log_level"] = validate_log_level(log_level=parts[i]) + i += 1 + elif parts[i] == "--dataset-names": + # Collect dataset names until next flag + result["dataset_names"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["dataset_names"].append(parts[i]) + i += 1 + elif parts[i] == "--max-dataset-size": + i += 1 + if i >= len(parts): + raise ValueError("--max-dataset-size requires a value") + result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) + i += 1 + else: + _logger.warning(f"Unknown argument: {parts[i]}") + i += 1 + + return result + + +# --------------------------------------------------------------------------- +# Shared argparse builder +# --------------------------------------------------------------------------- + + +def add_common_arguments(parser: argparse.ArgumentParser) -> None: + """Add arguments shared between pyrit_shell and pyrit_scan.""" + parser.add_argument("--config-file", type=Path, help=ARG_HELP["config_file"]) + parser.add_argument( + "--database", + type=validate_database_argparse, + default=None, + help=f"Database type to use ({IN_MEMORY}, {SQLITE}, {AZURE_SQL}). Defaults to config or {SQLITE}.", + ) + parser.add_argument( + "--log-level", + type=validate_log_level_argparse, + default=logging.WARNING, + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", + ) + parser.add_argument( + "--env-files", + type=str, + nargs="+", + help=ARG_HELP["env_files"], + ) + + +# Module-level logger (stdlib only — no heavy deps) +_logger = logging.getLogger(__name__) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 5ff336c18..ff4a6321c 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -15,14 +15,29 @@ from __future__ import annotations -import argparse -import inspect -import json import logging import sys from pathlib import Path from typing import TYPE_CHECKING, Any, Optional +from pyrit.cli._cli_args import ARG_HELP as ARG_HELP +from pyrit.cli._cli_args import AZURE_SQL as AZURE_SQL +from pyrit.cli._cli_args import IN_MEMORY as IN_MEMORY +from pyrit.cli._cli_args import SQLITE as SQLITE +from pyrit.cli._cli_args import _argparse_validator as _argparse_validator +from pyrit.cli._cli_args import _parse_initializer_arg as _parse_initializer_arg +from pyrit.cli._cli_args import add_common_arguments as add_common_arguments +from pyrit.cli._cli_args import non_negative_int as non_negative_int +from pyrit.cli._cli_args import parse_memory_labels as parse_memory_labels +from pyrit.cli._cli_args import parse_run_arguments as parse_run_arguments +from pyrit.cli._cli_args import positive_int as positive_int +from pyrit.cli._cli_args import resolve_env_files as resolve_env_files +from pyrit.cli._cli_args import resolve_env_files_argparse as resolve_env_files_argparse +from pyrit.cli._cli_args import validate_database as validate_database +from pyrit.cli._cli_args import validate_database_argparse as validate_database_argparse +from pyrit.cli._cli_args import validate_integer as validate_integer +from pyrit.cli._cli_args import validate_log_level as validate_log_level +from pyrit.cli._cli_args import validate_log_level_argparse as validate_log_level_argparse from pyrit.registry import InitializerRegistry, ScenarioRegistry from pyrit.scenario import DatasetConfiguration from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter @@ -47,7 +62,7 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Sequence from pyrit.models.scenario_result import ScenarioResult from pyrit.registry import ( @@ -57,11 +72,6 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i logger = logging.getLogger(__name__) -# Database type constants -IN_MEMORY = "InMemory" -SQLITE = "SQLite" -AZURE_SQL = "AzureSQL" - class FrontendCore: """ @@ -494,142 +504,6 @@ def format_initializer_metadata(*, initializer_metadata: InitializerMetadata) -> print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) -def validate_database(*, database: str) -> str: - """ - Validate database type. - - Args: - database: Database type string. - - Returns: - Validated database type. - - Raises: - ValueError: If database type is invalid. - """ - valid_databases = [IN_MEMORY, SQLITE, AZURE_SQL] - if database not in valid_databases: - raise ValueError(f"Invalid database type: {database}. Must be one of: {', '.join(valid_databases)}") - return database - - -def validate_log_level(*, log_level: str) -> int: - """ - Validate log level and convert to logging constant. - - Args: - log_level: Log level string (case-insensitive). - - Returns: - Validated log level as logging constant (e.g., logging.WARNING). - - Raises: - ValueError: If log level is invalid. - """ - valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - level_upper = log_level.upper() - if level_upper not in valid_levels: - raise ValueError(f"Invalid log level: {log_level}. Must be one of: {', '.join(valid_levels)}") - level_value: int = getattr(logging, level_upper) - return level_value - - -def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: - """ - Validate and parse an integer value. - - Note: The 'value' parameter is positional (not keyword-only) to allow use with - argparse lambdas like: lambda v: validate_integer(v, min_value=1). - This is an exception to the PyRIT style guide for argparse compatibility. - - Args: - value: String value to parse. - name: Parameter name for error messages. Defaults to "value". - min_value: Optional minimum value constraint. - - Returns: - Parsed integer. - - Raises: - ValueError: If value is not a valid integer or violates constraints. - """ - # Reject boolean types explicitly (int(True) == 1, int(False) == 0) - if isinstance(value, bool): - raise ValueError(f"{name} must be an integer string, got boolean: {value}") - - # Ensure value is a string - if not isinstance(value, str): - raise ValueError(f"{name} must be a string, got {type(value).__name__}: {value}") - - # Strip whitespace and validate it looks like an integer - value = value.strip() - if not value: - raise ValueError(f"{name} cannot be empty") - - try: - int_value = int(value) - except (ValueError, TypeError) as e: - raise ValueError(f"{name} must be an integer, got: {value}") from e - - if min_value is not None and int_value < min_value: - raise ValueError(f"{name} must be at least {min_value}, got: {int_value}") - - return int_value - - -def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], Any]: - """ - Adapt a validator to argparse by converting ValueError to ArgumentTypeError. - - This decorator adapts our keyword-only validators for use with argparse's type= parameter. - It handles two challenges: - - 1. Exception Translation: argparse expects ArgumentTypeError, but our validators raise - ValueError. This decorator catches ValueError and re-raises as ArgumentTypeError. - - 2. Keyword-Only Parameters: PyRIT validators use keyword-only parameters (e.g., - validate_database(*, database: str)), but argparse's type= passes a positional argument. - This decorator inspects the function signature and calls the validator with the correct - keyword argument name. - - This pattern allows us to: - - Keep validators as pure functions with proper type hints - - Follow PyRIT style guide (keyword-only parameters) - - Reuse the same validation logic in both argparse and non-argparse contexts - - Args: - validator_func: Function that raises ValueError on invalid input. - Must have at least one parameter (can be keyword-only). - - Returns: - Wrapped function that: - - Accepts a single positional argument (for argparse compatibility) - - Calls validator_func with the correct keyword argument - - Raises ArgumentTypeError instead of ValueError - - Raises: - ValueError: If validator_func has no parameters. - """ - # Get the first parameter name from the function signature - sig = inspect.signature(validator_func) - params = list(sig.parameters.keys()) - if not params: - raise ValueError(f"Validator function {validator_func.__name__} must have at least one parameter") - first_param = params[0] - - def wrapper(value: Any) -> Any: - try: - # Call with keyword argument to support keyword-only parameters - return validator_func(**{first_param: value}) - except ValueError as e: - raise argparse.ArgumentTypeError(str(e)) from e - - # Preserve function metadata for better debugging - wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") - wrapper.__doc__ = getattr(validator_func, "__doc__", None) - return wrapper - - def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: """ Resolve initialization script paths. @@ -646,76 +520,6 @@ def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: return InitializerRegistry.resolve_script_paths(script_paths=script_paths) -def resolve_env_files(*, env_file_paths: list[str]) -> list[Path]: - """ - Resolve environment file paths to absolute Path objects. - - Args: - env_file_paths: List of environment file path strings. - - Returns: - List of resolved Path objects. - - Raises: - ValueError: If any path does not exist. - """ - resolved_paths = [] - for path_str in env_file_paths: - path = Path(path_str).resolve() - if not path.exists(): - raise ValueError(f"Environment file not found: {path}") - resolved_paths.append(path) - return resolved_paths - - -# Argparse-compatible validators -# -# These wrappers adapt our core validators (which use keyword-only parameters and raise -# ValueError) for use with argparse's type= parameter (which passes positional arguments -# and expects ArgumentTypeError). -# -# Pattern: -# - Use core validators (validate_database, validate_log_level, etc.) in regular code -# - Use these _argparse versions ONLY in parser.add_argument(..., type=...) -# -# The lambda wrappers for validate_integer are necessary because we need to partially -# apply the min_value parameter while still allowing the decorator to work correctly. -validate_database_argparse = _argparse_validator(validate_database) -validate_log_level_argparse = _argparse_validator(validate_log_level) -positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) -non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) -resolve_env_files_argparse = _argparse_validator(resolve_env_files) - - -def parse_memory_labels(json_string: str) -> dict[str, str]: - """ - Parse memory labels from a JSON string. - - Args: - json_string: JSON string containing label key-value pairs. - - Returns: - Dictionary of labels. - - Raises: - ValueError: If JSON is invalid or contains non-string values. - """ - try: - labels = json.loads(json_string) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON for memory labels: {e}") from e - - if not isinstance(labels, dict): - raise ValueError("Memory labels must be a JSON object (dictionary)") - - # Validate all keys and values are strings - for key, value in labels.items(): - if not isinstance(key, str) or not isinstance(value, str): - raise ValueError(f"All label keys and values must be strings. Got: {key}={value}") - - return labels - - def get_default_initializer_discovery_path() -> Path: """ Get the default path for discovering initializers. @@ -776,200 +580,3 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path print("\n" + "=" * 80) print(f"\nTotal initializers: {len(initializers)}") return 0 - - -# Shared argument help text -ARG_HELP = { - "config_file": ( - "Path to a YAML configuration file. Allows specifying database, initializers (with args), " - "initialization scripts, and env files. CLI arguments override config file values. " - "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." - ), - "initializers": ( - "Built-in initializer names to run before the scenario. " - "Supports optional params with name:key=val syntax " - "(e.g., target:tags=default,scorer dataset:mode=strict)" - ), - "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", - "env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files " - "override earlier ones.", - "scenario_strategies": "List of strategy names to run (e.g., base64 rot13)", - "max_concurrency": "Maximum number of concurrent attack executions (must be >= 1)", - "max_retries": "Maximum number of automatic retries on exception (must be >= 0)", - "memory_labels": 'Additional labels as JSON string (e.g., \'{"experiment": "test1"}\')', - "database": "Database type to use for memory storage", - "log_level": "Logging level", - "dataset_names": "List of dataset names to use instead of scenario defaults (e.g., harmbench advbench). " - "Creates a new dataset config; fetches all items unless --max-dataset-size is also specified", - "max_dataset_size": "Maximum number of items to use from the dataset (must be >= 1). " - "Limits new datasets if --dataset-names provided, otherwise overrides scenario's default limit", -} - - -def _parse_initializer_arg(arg: str) -> str | dict[str, Any]: - """ - Parse an initializer CLI argument into a string or dict for ConfigurationLoader. - - Supports two formats: - - Simple name: "simple" → "simple" - - Name with params: "target:tags=default,scorer" → {"name": "target", "args": {"tags": ["default", "scorer"]}} - - For multiple params on one initializer, separate with semicolons: "name:key1=val1;key2=val2" - For multiple initializers with params, space-separate them: "target:tags=a,b dataset:mode=strict" - - Args: - arg: The CLI argument string. - - Returns: - str | dict[str, Any]: A plain name string, or a dict with 'name' and 'args' keys. - - Raises: - ValueError: If the argument format is invalid. - """ - if ":" not in arg: - return arg - - name, params_str = arg.split(":", 1) - if not name: - raise ValueError(f"Invalid initializer argument '{arg}': missing name before ':'") - - args: dict[str, list[str]] = {} - for pair in params_str.split(";"): - pair = pair.strip() - if not pair: - continue - if "=" not in pair: - raise ValueError(f"Invalid initializer parameter '{pair}' in '{arg}': expected key=value format") - key, value = pair.split("=", 1) - key = key.strip() - if not key: - raise ValueError(f"Invalid initializer parameter in '{arg}': empty key") - args[key] = [v.strip() for v in value.split(",")] - - if args: - return {"name": name, "args": args} - return name - - -def parse_run_arguments(*, args_string: str) -> dict[str, Any]: - """ - Parse run command arguments from a string (for shell mode). - - Args: - args_string: Space-separated argument string (e.g., "scenario_name --initializers foo --strategies bar"). - - Returns: - Dictionary with parsed arguments: - - scenario_name: str - - initializers: Optional[list[str]] - - initialization_scripts: Optional[list[str]] - - scenario_strategies: Optional[list[str]] - - max_concurrency: Optional[int] - - max_retries: Optional[int] - - memory_labels: Optional[dict[str, str]] - - database: Optional[str] - - log_level: Optional[int] - - dataset_names: Optional[list[str]] - - max_dataset_size: Optional[int] - - Raises: - ValueError: If parsing or validation fails. - """ - parts = args_string.split() - - if not parts: - raise ValueError("No scenario name provided") - - result: dict[str, Any] = { - "scenario_name": parts[0], - "initializers": None, - "initialization_scripts": None, - "env_files": None, - "scenario_strategies": None, - "max_concurrency": None, - "max_retries": None, - "memory_labels": None, - "database": None, - "log_level": None, - "dataset_names": None, - "max_dataset_size": None, - } - - i = 1 - while i < len(parts): - if parts[i] == "--initializers": - # Collect initializers until next flag, parsing name:key=val syntax - result["initializers"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initializers"].append(_parse_initializer_arg(parts[i])) - i += 1 - elif parts[i] == "--initialization-scripts": - # Collect script paths until next flag - result["initialization_scripts"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["initialization_scripts"].append(parts[i]) - i += 1 - elif parts[i] == "--env-files": - # Collect env file paths until next flag - result["env_files"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["env_files"].append(parts[i]) - i += 1 - elif parts[i] in ("--strategies", "-s"): - # Collect strategies until next flag - result["scenario_strategies"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s": - result["scenario_strategies"].append(parts[i]) - i += 1 - elif parts[i] == "--max-concurrency": - i += 1 - if i >= len(parts): - raise ValueError("--max-concurrency requires a value") - result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) - i += 1 - elif parts[i] == "--max-retries": - i += 1 - if i >= len(parts): - raise ValueError("--max-retries requires a value") - result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) - i += 1 - elif parts[i] == "--memory-labels": - i += 1 - if i >= len(parts): - raise ValueError("--memory-labels requires a value") - result["memory_labels"] = parse_memory_labels(parts[i]) - i += 1 - elif parts[i] == "--database": - i += 1 - if i >= len(parts): - raise ValueError("--database requires a value") - result["database"] = validate_database(database=parts[i]) - i += 1 - elif parts[i] == "--log-level": - i += 1 - if i >= len(parts): - raise ValueError("--log-level requires a value") - result["log_level"] = validate_log_level(log_level=parts[i]) - i += 1 - elif parts[i] == "--dataset-names": - # Collect dataset names until next flag - result["dataset_names"] = [] - i += 1 - while i < len(parts) and not parts[i].startswith("--"): - result["dataset_names"].append(parts[i]) - i += 1 - elif parts[i] == "--max-dataset-size": - i += 1 - if i >= len(parts): - raise ValueError("--max-dataset-size requires a value") - result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) - i += 1 - else: - logger.warning(f"Unknown argument: {parts[i]}") - i += 1 - - return result diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index d674d5617..28d713b93 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -145,11 +145,12 @@ def __init__( def _background_init(self) -> None: """Import heavy modules and initialize PyRIT in the background.""" try: + from pyrit.cli import frontend_core as fc + + self._fc = fc if self._deprecated_context is not None: self.context = self._deprecated_context else: - from pyrit.cli import frontend_core as fc - self.context = fc.FrontendCore(**self._context_kwargs) self.default_database = self.context._database self.default_log_level = self.context._log_level @@ -198,24 +199,18 @@ def cmdloop(self, intro: Optional[str] = None) -> None: def do_list_scenarios(self, arg: str) -> None: """List all available scenarios.""" self._ensure_initialized() - from pyrit.cli import frontend_core - try: - asyncio.run(frontend_core.print_scenarios_list_async(context=self.context)) + asyncio.run(self._fc.print_scenarios_list_async(context=self.context)) except Exception as e: print(f"Error listing scenarios: {e}") def do_list_initializers(self, arg: str) -> None: """List all available initializers.""" self._ensure_initialized() - from pyrit.cli import frontend_core - try: # Discover from scenarios directory by default (same as scan) - discovery_path = frontend_core.get_default_initializer_discovery_path() - asyncio.run( - frontend_core.print_initializers_list_async(context=self.context, discovery_path=discovery_path) - ) + discovery_path = self._fc.get_default_initializer_discovery_path() + asyncio.run(self._fc.print_initializers_list_async(context=self.context, discovery_path=discovery_path)) except Exception as e: print(f"Error listing initializers: {e}") @@ -261,25 +256,24 @@ def do_run(self, line: str) -> None: Initializers are specified per-run to allow different setups for different scenarios. """ self._ensure_initialized() - from pyrit.cli import frontend_core if not line.strip(): print("Error: Specify a scenario name") print("\nUsage: run [options]") print("\nNote: Every scenario requires an initializer.") print("\nOptions:") - print(f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") + print(f" --initializers ... {self._fc.ARG_HELP['initializers']} (REQUIRED)") print( - f" --initialization-scripts <...> {frontend_core.ARG_HELP['initialization_scripts']}" + f" --initialization-scripts <...> {self._fc.ARG_HELP['initialization_scripts']}" " (alternative to --initializers)" ) - print(f" --strategies, -s ... {frontend_core.ARG_HELP['scenario_strategies']}") - print(f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") - print(f" --max-retries {frontend_core.ARG_HELP['max_retries']}") - print(f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") + print(f" --strategies, -s ... {self._fc.ARG_HELP['scenario_strategies']}") + print(f" --max-concurrency {self._fc.ARG_HELP['max_concurrency']}") + print(f" --max-retries {self._fc.ARG_HELP['max_retries']}") + print(f" --memory-labels {self._fc.ARG_HELP['memory_labels']}") print( f" --database Override default database" - f" ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL})" + f" ({self._fc.IN_MEMORY}, {self._fc.SQLITE}, {self._fc.AZURE_SQL})" ) print( " --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" @@ -291,7 +285,7 @@ def do_run(self, line: str) -> None: # Parse arguments using shared parser try: - args = frontend_core.parse_run_arguments(args_string=line) + args = self._fc.parse_run_arguments(args_string=line) except ValueError as e: print(f"Error: {e}") return @@ -300,9 +294,7 @@ def do_run(self, line: str) -> None: resolved_scripts = None if args["initialization_scripts"]: try: - resolved_scripts = frontend_core.resolve_initialization_scripts( - script_paths=args["initialization_scripts"] - ) + resolved_scripts = self._fc.resolve_initialization_scripts(script_paths=args["initialization_scripts"]) except FileNotFoundError as e: print(f"Error: {e}") return @@ -311,7 +303,7 @@ def do_run(self, line: str) -> None: resolved_env_files: Optional[list[Path]] = None if args["env_files"]: try: - resolved_env_files = list(frontend_core.resolve_env_files(env_file_paths=args["env_files"])) + resolved_env_files = list(self._fc.resolve_env_files(env_file_paths=args["env_files"])) except ValueError as e: print(f"Error: {e}") return @@ -320,7 +312,7 @@ def do_run(self, line: str) -> None: resolved_env_files = list(self.default_env_files) if self.default_env_files else None # Create a context for this run with overrides - run_context = frontend_core.FrontendCore( + run_context = self._fc.FrontendCore( database=args["database"] or self.default_database, initialization_scripts=resolved_scripts, initializer_names=args["initializers"], @@ -334,7 +326,7 @@ def do_run(self, line: str) -> None: try: result = asyncio.run( - frontend_core.run_scenario_async( + self._fc.run_scenario_async( scenario_name=args["scenario_name"], context=run_context, scenario_strategies=args["scenario_strategies"], @@ -435,7 +427,6 @@ def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" if not arg: self._ensure_initialized() - from pyrit.cli import frontend_core # Show general help super().do_help(arg) @@ -456,7 +447,7 @@ def do_help(self, arg: str) -> None: print("Run Command Options (specified when running scenarios):") print("=" * 70) print(" --initializers [ ...] (REQUIRED)") - print(f" {frontend_core.ARG_HELP['initializers']}") + print(f" {self._fc.ARG_HELP['initializers']}") print(" Every scenario requires at least one initializer") print(" Example: run foundry --initializers openai_objective_target load_default_datasets") print(" With params: run foundry --initializers target:tags=default,scorer") @@ -465,21 +456,21 @@ def do_help(self, arg: str) -> None: ) print() print(" --initialization-scripts [ ...] (Alternative to --initializers)") - print(f" {frontend_core.ARG_HELP['initialization_scripts']}") + print(f" {self._fc.ARG_HELP['initialization_scripts']}") print(" Example: run foundry --initialization-scripts ./my_init.py") print() print(" --strategies, -s [ ...]") - print(f" {frontend_core.ARG_HELP['scenario_strategies']}") + print(f" {self._fc.ARG_HELP['scenario_strategies']}") print(" Example: run garak.encoding --strategies base64 rot13") print() print(" --max-concurrency ") - print(f" {frontend_core.ARG_HELP['max_concurrency']}") + print(f" {self._fc.ARG_HELP['max_concurrency']}") print() print(" --max-retries ") - print(f" {frontend_core.ARG_HELP['max_retries']}") + print(f" {self._fc.ARG_HELP['max_retries']}") print() print(" --memory-labels ") - print(f" {frontend_core.ARG_HELP['memory_labels']}") + print(f" {self._fc.ARG_HELP['memory_labels']}") print(' Example: run foundry --memory-labels \'{"env":"test"}\'') print() print("Start the shell like:") @@ -546,6 +537,8 @@ def main() -> int: """ import argparse + from pyrit.cli._cli_args import ARG_HELP, AZURE_SQL, IN_MEMORY, SQLITE + parser = argparse.ArgumentParser( prog="pyrit_shell", description="PyRIT Interactive Shell - Load modules once, run commands instantly", @@ -554,20 +547,16 @@ def main() -> int: parser.add_argument( "--config-file", type=Path, - help=( - "Path to a YAML configuration file. Allows specifying database, initializers, " - "initialization scripts, and env files. CLI arguments override config file values. " - "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." - ), + help=ARG_HELP["config_file"], ) parser.add_argument( "--database", - choices=["InMemory", "SQLite", "AzureSQL"], + choices=[IN_MEMORY, SQLITE, AZURE_SQL], default=None, help=( - "Default database type to use (InMemory, SQLite, AzureSQL)" - " (defaults to config file value, or SQLite if not specified)" + f"Default database type to use ({IN_MEMORY}, {SQLITE}, {AZURE_SQL})" + f" (defaults to config file value, or {SQLITE} if not specified)" ), ) @@ -586,7 +575,7 @@ def main() -> int: "--env-files", type=str, nargs="+", - help="Environment files to load in order (default for all runs, can be overridden per-run)", + help=ARG_HELP["env_files"], ) parser.add_argument( diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 1773278a4..651a7933a 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -49,6 +49,9 @@ def shell(): with patch.object(pyrit_shell.PyRITShell, "_background_init"): s = pyrit_shell.PyRITShell() # Manually set the state that _background_init would have set + from pyrit.cli import frontend_core as fc_module + + s._fc = fc_module s.context = mock_context s.default_database = mock_context._database s.default_log_level = mock_context._log_level From f7ced9a00b17fc347183529c49f8d33c9ffb00d3 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 18 Mar 2026 12:17:31 -0700 Subject: [PATCH 13/14] Address code review feedback: fix race-safe logging, fast help, type correctness - Use logging.disable() instead of setLevel() to avoid race with background thread - do_help() no longer blocks on full init; imports ARG_HELP from lightweight _cli_args - Convert string log_level to int via validate_log_level() in main() - Validate env file existence with resolve_env_files() in main() - Move print_deprecation_message to top-level import - Fix parse_run_arguments docstring: initializers type is str | dict - Include stderr in startup integration test failure message Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/_cli_args.py | 2 +- pyrit/cli/pyrit_shell.py | 50 +++++++++++-------- .../cli/test_pyrit_shell_startup.py | 5 +- tests/unit/cli/test_pyrit_shell.py | 5 +- 4 files changed, 35 insertions(+), 27 deletions(-) diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index 29d1810db..f024e6b99 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -349,7 +349,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: Returns: Dictionary with parsed arguments: - scenario_name: str - - initializers: Optional[list[str]] + - initializers: Optional[list[str | dict[str, Any]]] - initialization_scripts: Optional[list[str]] - scenario_strategies: Optional[list[str]] - max_concurrency: Optional[int] diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index 28d713b93..f6020d319 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -25,6 +25,7 @@ from pyrit.models.scenario_result import ScenarioResult from pyrit.cli import _banner as banner +from pyrit.common.deprecation import print_deprecation_message class PyRITShell(cmd.Cmd): @@ -116,8 +117,6 @@ def __init__( "Cannot pass 'context' together with FrontendCore keyword arguments " f"({', '.join(self._context_kwargs)}). Use one or the other." ) - from pyrit.common.deprecation import print_deprecation_message - print_deprecation_message( old_item="PyRITShell(context=...)", new_item="PyRITShell(database=..., log_level=..., ...)", @@ -180,13 +179,12 @@ def cmdloop(self, intro: Optional[str] = None) -> None: # Play animation immediately while background init continues. # Suppress logging during the animation so log lines don't corrupt # the ANSI cursor-positioned frames. - root_logger = logging.getLogger() - prev_level = root_logger.level - root_logger.setLevel(logging.CRITICAL) + prev_disable = logging.root.manager.disable + logging.disable(logging.CRITICAL) try: intro = banner.play_animation(no_animation=self._no_animation) finally: - root_logger.setLevel(prev_level) + logging.disable(prev_disable) # If init already failed while the animation played, surface it now. if self._init_complete.is_set(): @@ -426,9 +424,9 @@ def do_print_scenario(self, arg: str) -> None: def do_help(self, arg: str) -> None: """Show help. Usage: help [command].""" if not arg: - self._ensure_initialized() + from pyrit.cli._cli_args import ARG_HELP, AZURE_SQL, IN_MEMORY, SQLITE - # Show general help + # Show general help (no full init needed — ARG_HELP is lightweight) super().do_help(arg) print("\n" + "=" * 70) print("Shell Startup Options:") @@ -447,7 +445,7 @@ def do_help(self, arg: str) -> None: print("Run Command Options (specified when running scenarios):") print("=" * 70) print(" --initializers [ ...] (REQUIRED)") - print(f" {self._fc.ARG_HELP['initializers']}") + print(f" {ARG_HELP['initializers']}") print(" Every scenario requires at least one initializer") print(" Example: run foundry --initializers openai_objective_target load_default_datasets") print(" With params: run foundry --initializers target:tags=default,scorer") @@ -456,23 +454,26 @@ def do_help(self, arg: str) -> None: ) print() print(" --initialization-scripts [ ...] (Alternative to --initializers)") - print(f" {self._fc.ARG_HELP['initialization_scripts']}") + print(f" {ARG_HELP['initialization_scripts']}") print(" Example: run foundry --initialization-scripts ./my_init.py") print() print(" --strategies, -s [ ...]") - print(f" {self._fc.ARG_HELP['scenario_strategies']}") + print(f" {ARG_HELP['scenario_strategies']}") print(" Example: run garak.encoding --strategies base64 rot13") print() print(" --max-concurrency ") - print(f" {self._fc.ARG_HELP['max_concurrency']}") + print(f" {ARG_HELP['max_concurrency']}") print() print(" --max-retries ") - print(f" {self._fc.ARG_HELP['max_retries']}") + print(f" {ARG_HELP['max_retries']}") print() print(" --memory-labels ") - print(f" {self._fc.ARG_HELP['memory_labels']}") + print(f" {ARG_HELP['memory_labels']}") print(' Example: run foundry --memory-labels \'{"env":"test"}\'') print() + print(f" --database Override ({IN_MEMORY}, {SQLITE}, {AZURE_SQL})") + print(" --log-level Override (DEBUG, INFO, WARNING, ERROR, CRITICAL)") + print() print("Start the shell like:") print(" pyrit_shell") print(" pyrit_shell --database InMemory --log-level DEBUG") @@ -537,7 +538,7 @@ def main() -> int: """ import argparse - from pyrit.cli._cli_args import ARG_HELP, AZURE_SQL, IN_MEMORY, SQLITE + from pyrit.cli._cli_args import ARG_HELP, AZURE_SQL, IN_MEMORY, SQLITE, validate_log_level parser = argparse.ArgumentParser( prog="pyrit_shell", @@ -587,20 +588,25 @@ def main() -> int: args = parser.parse_args() - # Resolve env file paths (lightweight — no heavy imports needed). + # Resolve and validate env file paths (lightweight — no heavy imports needed). env_files: Optional[list[Path]] = None if args.env_files: - env_files = [Path(p).resolve() for p in args.env_files] + from pyrit.cli._cli_args import resolve_env_files + + try: + env_files = resolve_env_files(env_file_paths=args.env_files) + except ValueError as e: + print(f"Error: {e}") + return 1 # Play the banner immediately, before heavy imports. # Suppress logging so background-thread output doesn't corrupt the animation. - root_logger = logging.getLogger() - prev_level = root_logger.level - root_logger.setLevel(logging.CRITICAL) + prev_disable = logging.root.manager.disable + logging.disable(logging.CRITICAL) try: intro = banner.play_animation(no_animation=args.no_animation) finally: - root_logger.setLevel(prev_level) + logging.disable(prev_disable) # Create shell with deferred initialization — the background thread # will import frontend_core, create the FrontendCore context, and call @@ -613,7 +619,7 @@ def main() -> int: initialization_scripts=None, initializer_names=None, env_files=env_files, - log_level=args.log_level, + log_level=validate_log_level(log_level=args.log_level), ) shell.cmdloop(intro=intro) return 0 diff --git a/tests/integration/cli/test_pyrit_shell_startup.py b/tests/integration/cli/test_pyrit_shell_startup.py index a16fcaf8b..ea43af368 100644 --- a/tests/integration/cli/test_pyrit_shell_startup.py +++ b/tests/integration/cli/test_pyrit_shell_startup.py @@ -44,7 +44,8 @@ def test_pyrit_shell_module_imports_within_budget() -> None: elapsed_str = result.stdout.strip() assert result.returncode == 0, ( - f"pyrit_shell module import took {elapsed_str}s, " + f"pyrit_shell module import took {elapsed_str or '?'}s, " f"exceeding the {_MAX_IMPORT_SECONDS}s budget. " - "Check for heavy top-level imports in pyrit_shell.py." + f"Check for heavy top-level imports in pyrit_shell.py.\n" + f"returncode={result.returncode}\nstderr: {result.stderr.strip()}" ) diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py index 651a7933a..58043c550 100644 --- a/tests/unit/cli/test_pyrit_shell.py +++ b/tests/unit/cli/test_pyrit_shell.py @@ -6,6 +6,7 @@ """ import cmd +import logging from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -637,7 +638,7 @@ def test_main_default_args(self, mock_play: MagicMock, mock_shell_class: MagicMo assert result == 0 call_kwargs = mock_shell_class.call_args[1] assert call_kwargs["database"] is None - assert call_kwargs["log_level"] == "WARNING" + assert call_kwargs["log_level"] == logging.WARNING mock_shell.cmdloop.assert_called_once() @patch("pyrit.cli.pyrit_shell.PyRITShell") @@ -666,7 +667,7 @@ def test_main_with_log_level_arg(self, mock_play: MagicMock, mock_shell_class: M assert result == 0 call_kwargs = mock_shell_class.call_args[1] - assert call_kwargs["log_level"] == "DEBUG" + assert call_kwargs["log_level"] == logging.DEBUG @patch("pyrit.cli.pyrit_shell.PyRITShell") @patch("pyrit.cli._banner.play_animation", return_value="") From b18558b614268d8573ea3a97857039c0546e013a Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Wed, 18 Mar 2026 13:32:42 -0700 Subject: [PATCH 14/14] Raise ValueError on unknown run arguments instead of logging warning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/_cli_args.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyrit/cli/_cli_args.py b/pyrit/cli/_cli_args.py index f024e6b99..8472a5afe 100644 --- a/pyrit/cli/_cli_args.py +++ b/pyrit/cli/_cli_args.py @@ -457,8 +457,7 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: result["max_dataset_size"] = validate_integer(parts[i], name="--max-dataset-size", min_value=1) i += 1 else: - _logger.warning(f"Unknown argument: {parts[i]}") - i += 1 + raise ValueError(f"Unknown argument: {parts[i]}") return result