diff --git a/.flake8 b/.flake8 index dc7f52bd1..923c79c3e 100644 --- a/.flake8 +++ b/.flake8 @@ -9,9 +9,21 @@ exclude = .git, dist, doc, + .github, *lib/python*, *egg, - build + build, + pyrit/cli/pyrit_shell.py, + pyrit/prompt_converter/morse_converter.py, + pyrit/prompt_converter/emoji_converter.py, + pyrit/scenarios/printer/console_printer.py, + tests/unit/converter/test_prompt_converter.py, + tests/unit/converter/test_unicode_confusable_converter.py, + tests/unit/converter/test_first_letter_converter.py, + tests/unit/converter/test_base2048_converter.py, + tests/unit/converter/test_ecoji_converter.py, + tests/unit/converter/test_bin_ascii_converter.py, + tests/unit/models/test_seed.py per-file-ignores = ./pyrit/score/gpt_classifier.py:E501,W291 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dbec6f246..751991e06 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,7 +61,7 @@ repos: - id: flake8 additional_dependencies: ['flake8-copyright'] types: [python] - exclude: (doc/|.github/|pyrit/prompt_converter/morse_converter.py|tests/unit/converter/test_prompt_converter.py|pyrit/prompt_converter/emoji_converter.py|tests/unit/models/test_seed.py|tests/unit/converter/test_unicode_confusable_converter.py|tests/unit/converter/test_first_letter_converter.py|tests/unit/converter/test_base2048_converter.py|tests/unit/converter/test_ecoji_converter.py|tests/unit/converter/test_bin_ascii_converter.py|pyrit/scenarios/printer/console_printer.py) + exclude: ^doc/ - repo: local hooks: diff --git a/doc/_toc.yml b/doc/_toc.yml index d0c6274ea..52dedc5d6 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -129,7 +129,10 @@ chapters: sections: - file: code/auxiliary_attacks/1_gcg_azure_ml - file: code/scenarios/scenarios - - file: code/front_end/0_cli + - file: code/front_end/0_front_end + sections: + - file: code/front_end/1_pyrit_scan + - file: code/front_end/2_pyrit_shell - file: deployment/README sections: - file: deployment/deploy_hf_model_aml diff --git a/doc/code/front_end/0_front_end.md b/doc/code/front_end/0_front_end.md new file mode 100644 index 000000000..9463e4173 --- /dev/null +++ b/doc/code/front_end/0_front_end.md @@ -0,0 +1,40 @@ +# PyRIT Command-Line Frontends + +PyRIT provides two command-line interfaces for running AI red teaming scenarios: + +## pyrit_scan - Single-Command Execution + +`pyrit_scan` is designed for **automated, non-interactive scenario execution**. It's ideal for: +- CI/CD pipelines and automated testing workflows +- Batch processing multiple scenarios with scripts +- One-time security assessments +- Reproducible test runs with exact parameters + +Each invocation runs a single scenario with specified parameters and exits, making it perfect for automation where you need clean, scriptable execution with predictable exit codes. + +**Key characteristics:** +- Loads PyRIT modules fresh for each execution +- Runs one scenario per command +- Exits with status code (0 for success, non-zero for errors) +- Output can be easily captured and parsed + +**Documentation:** [1_pyrit_scan.ipynb](1_pyrit_scan.ipynb) + +## pyrit_shell - Interactive Session + +`pyrit_shell` is an **interactive REPL (Read-Eval-Print Loop)** for exploratory testing. 
It's ideal for: +- Interactive scenario development and debugging +- Rapid iteration and experimentation +- Exploring multiple scenarios without reload overhead +- Session-based result tracking and comparison + +The shell loads PyRIT modules once at startup and maintains a persistent session, allowing you to run multiple scenarios quickly and review their results interactively. + +**Key characteristics:** +- Fast subsequent executions (modules loaded once) +- Session history of all runs +- Interactive result exploration and printing +- Persistent context across multiple scenario runs +- Tab completion and command help + +**Documentation:** [2_pyrit_shell.md](2_pyrit_shell.md) diff --git a/doc/code/front_end/0_cli.ipynb b/doc/code/front_end/1_pyrit_scan.ipynb similarity index 66% rename from doc/code/front_end/0_cli.ipynb rename to doc/code/front_end/1_pyrit_scan.ipynb index f79282226..f83dfa93f 100644 --- a/doc/code/front_end/0_cli.ipynb +++ b/doc/code/front_end/1_pyrit_scan.ipynb @@ -5,9 +5,9 @@ "id": "0", "metadata": {}, "source": [ - "# The PyRIT CLI\n", + "# 1. PyRIT Scan - Command Line Execution\n", "\n", - "The PyRIT cli tool that allows you to run automated security testing and red teaming attacks against AI systems using [scenarios](../scenarios/scenarios.ipynb) for strategies and [configuration](../setup/1_configuration.ipynb).\n", + "`pyrit_scan` allows you to run automated security testing and red teaming attacks against AI systems using [scenarios](../scenarios/scenarios.ipynb) for strategies and [configuration](../setup/1_configuration.ipynb).\n", "\n", "Note in this doc the ! prefaces all commands in the terminal so we can run in a Jupyter Notebook.\n", "\n", @@ -26,11 +26,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "usage: pyrit_scan [-h] [--verbose] [--list-scenarios] [--list-initializers]\n", - " [--database {InMemory,SQLite,AzureSQL}]\n", + "Starting PyRIT...\n", + "usage: pyrit_scan [-h] [--log-level LOG_LEVEL] [--list-scenarios]\n", + " [--list-initializers] [--database DATABASE]\n", " [--initializers INITIALIZERS [INITIALIZERS ...]]\n", " [--initialization-scripts INITIALIZATION_SCRIPTS [INITIALIZATION_SCRIPTS ...]]\n", - " [--scenario-strategies SCENARIO_STRATEGIES [SCENARIO_STRATEGIES ...]]\n", + " [--strategies SCENARIO_STRATEGIES [SCENARIO_STRATEGIES ...]]\n", + " [--max-concurrency MAX_CONCURRENCY]\n", + " [--max-retries MAX_RETRIES] [--memory-labels MEMORY_LABELS]\n", " [scenario_name]\n", "\n", "PyRIT Scanner - Run security scenarios against AI systems\n", @@ -41,35 +44,45 @@ " pyrit_scan --list-initializers\n", "\n", " # Run a scenario with built-in initializers\n", - " pyrit_scan foundry_scenario --initializers simple objective_target\n", + " pyrit_scan foundry_scenario --initializers openai_objective_target\n", "\n", " # Run with custom initialization scripts\n", " pyrit_scan encoding_scenario --initialization-scripts ./my_config.py\n", - " \n", + "\n", " # Run specific strategies\n", - " pyrit_scan encoding_scenario --initializers simple objective_target --scenario-strategies base64 rot13 morse_code\n", - " pyrit_scan foundry_scenario --initializers simple objective_target --scenario-strategies base64 atbash\n", + " pyrit_scan encoding_scenario --initializers openai_objective_target --strategies base64 rot13\n", + " pyrit_scan foundry_scenario --initializers openai_objective_target --max-concurrency 10 --max-retries 3\n", + " pyrit_scan encoding_scenario --initializers openai_objective_target --memory-labels '{\"run_id\":\"test123\"}'\n", 
"\n", "positional arguments:\n", - " scenario_name Name of the scenario to run (e.g., encoding_scenario,\n", - " foundry_scenario)\n", + " scenario_name Name of the scenario to run\n", "\n", "options:\n", " -h, --help show this help message and exit\n", - " --verbose Enable verbose logging output\n", + " --log-level LOG_LEVEL\n", + " Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)\n", + " (default: WARNING)\n", " --list-scenarios List all available scenarios and exit\n", " --list-initializers List all available scenario initializers and exit\n", - " --database {InMemory,SQLite,AzureSQL}\n", - " Database type to use for memory storage\n", + " --database DATABASE Database type to use for memory storage (InMemory,\n", + " SQLite, AzureSQL) (default: SQLite)\n", " --initializers INITIALIZERS [INITIALIZERS ...]\n", - " Built-in initializer names (e.g., simple,\n", - " objective_target, objective_list)\n", + " Built-in initializer names to run before the scenario\n", + " (e.g., openai_objective_target)\n", " --initialization-scripts INITIALIZATION_SCRIPTS [INITIALIZATION_SCRIPTS ...]\n", - " Paths to custom Python initialization scripts that\n", - " configure scenarios and defaults\n", - " --scenario-strategies SCENARIO_STRATEGIES [SCENARIO_STRATEGIES ...]\n", - " List of strategy names to run (e.g., base64 rot13\n", - " morse_code). If not specified, uses scenario defaults.\n" + " Paths to custom Python initialization scripts to run\n", + " before the scenario\n", + " --strategies SCENARIO_STRATEGIES [SCENARIO_STRATEGIES ...], -s SCENARIO_STRATEGIES [SCENARIO_STRATEGIES ...]\n", + " List of strategy names to run (e.g., base64 rot13)\n", + " --max-concurrency MAX_CONCURRENCY\n", + " Maximum number of concurrent attack executions (must\n", + " be >= 1)\n", + " --max-retries MAX_RETRIES\n", + " Maximum number of automatic retries on exception (must\n", + " be >= 0)\n", + " --memory-labels MEMORY_LABELS\n", + " Additional labels as JSON string (e.g.,\n", + " '{\"experiment\": \"test1\"}')\n" ] } ], @@ -97,23 +110,26 @@ "name": "stdout", "output_type": "stream", "text": [ + "Starting PyRIT...\n", + "Loading PyRIT modules...\n", + "Ready!\n", "\n", "Available Scenarios:\n", "================================================================================\n", - "\n", - " encoding_scenario\n", + "\u001b[1m\u001b[36m\n", + " encoding_scenario\u001b[0m\n", " Class: EncodingScenario\n", " Description:\n", " Encoding Scenario implementation for PyRIT. This scenario tests how\n", " resilient models are to various encoding attacks by encoding potentially\n", - " harmful text (by default slurs and XSS payloads) and testing if the model\n", - " will decode and repeat the encoded payload. It mimics the Garak encoding\n", - " probe. The scenario works by: 1. Taking seed prompts (the harmful text to\n", - " be encoded) 2. Encoding them using various encoding schemes (Base64,\n", - " ROT13, Morse, etc.) 3. Asking the target model to decode the encoded text\n", - " 4. Scoring whether the model successfully decoded and repeated the harmful\n", - " content By default, this uses the same dataset as Garak: slur terms and\n", - " web XSS payloads.\n", + " harmful text (by default slurs and XSS payloads) and testing if the\n", + " model will decode and repeat the encoded payload. It mimics the Garak\n", + " encoding probe. The scenario works by: 1. Taking seed prompts (the\n", + " harmful text to be encoded) 2. Encoding them using various encoding\n", + " schemes (Base64, ROT13, Morse, etc.) 3. 
Asking the target model to\n", + " decode the encoded text 4. Scoring whether the model successfully\n", + " decoded and repeated the harmful content By default, this uses the same\n", + " dataset as Garak: slur terms and web XSS payloads.\n", " Aggregate Strategies:\n", " - all\n", " Available Strategies (17):\n", @@ -121,36 +137,32 @@ " uuencode, rot13, braille, atbash, morse_code, nato, ecoji, zalgo,\n", " leet_speak, ascii_smuggler\n", " Default Strategy: all\n", - "\n", - " foundry_scenario\n", + "\u001b[1m\u001b[36m\n", + " foundry_scenario\u001b[0m\n", " Class: FoundryScenario\n", " Description:\n", " FoundryScenario is a preconfigured scenario that automatically generates\n", - " multiple AtomicAttack instances based on the specified attack strategies.\n", - " It supports both single-turn attacks (with various converters) and\n", - " multi-turn attacks (Crescendo, RedTeaming), making it easy to quickly test\n", - " a target against multiple attack vectors. The scenario can expand\n", - " difficulty levels (EASY, MODERATE, DIFFICULT) into their constituent\n", - " attack strategies, or you can specify individual strategies directly. Note\n", - " this is not the same as the Foundry AI Red Teaming Agent. This is a PyRIT\n", - " contract so their library can make use of PyRIT in a consistent way.\n", + " multiple AtomicAttack instances based on the specified attack\n", + " strategies. It supports both single-turn attacks (with various\n", + " converters) and multi-turn attacks (Crescendo, RedTeaming), making it\n", + " easy to quickly test a target against multiple attack vectors. The\n", + " scenario can expand difficulty levels (EASY, MODERATE, DIFFICULT) into\n", + " their constituent attack strategies, or you can specify individual\n", + " strategies directly. Note this is not the same as the Foundry AI Red\n", + " Teaming Agent. 
This is a PyRIT contract so their library can make use of\n", + " PyRIT in a consistent way.\n", " Aggregate Strategies:\n", - " - all\n", - " - easy\n", - " - moderate\n", - " - difficult\n", + " - all, easy, moderate, difficult\n", " Available Strategies (23):\n", " ansi_attack, ascii_art, ascii_smuggler, atbash, base64, binary, caesar,\n", " character_space, char_swap, diacritic, flip, leetspeak, morse, rot13,\n", - " suffix_append, string_join, unicode_confusable, unicode_substitution, url,\n", - " jailbreak, tense, multi_turn, crescendo\n", + " suffix_append, string_join, unicode_confusable, unicode_substitution,\n", + " url, jailbreak, tense, multi_turn, crescendo\n", " Default Strategy: easy\n", "\n", "================================================================================\n", "\n", - "Total scenarios: 2\n", - "\n", - "For usage information, use: pyrit_scan --help\n" + "Total scenarios: 2\n" ] } ], @@ -190,19 +202,20 @@ "name": "stdout", "output_type": "stream", "text": [ + "Starting PyRIT...\n", "\n", - "Available Scenario Initializers:\n", + "Available Initializers:\n", "================================================================================\n", - "\n", - " objective_list\n", + "\u001b[1m\u001b[36m\n", + " objective_list\u001b[0m\n", " Class: ScenarioObjectiveListInitializer\n", " Name: Simple Objective List Configuration for Scenarios\n", " Execution Order: 10\n", " Required Environment Variables: None\n", " Description:\n", " Simple Objective List Configuration for Scenarios\n", - "\n", - " openai_objective_target\n", + "\u001b[1m\u001b[36m\n", + " openai_objective_target\u001b[0m\n", " Class: ScenarioObjectiveTargetInitializer\n", " Name: Simple Objective Target Configuration for Scenarios\n", " Execution Order: 10\n", @@ -211,14 +224,13 @@ " - DEFAULT_OPENAI_FRONTEND_KEY\n", " Description:\n", " This configuration sets up a simple objective target for scenarios using\n", - " OpenAIChatTarget with basic settings. It initializes an openAI chat target\n", - " using the OPENAI_CLI_ENDPOINT and OPENAI_CLI_KEY environment variables.\n", + " OpenAIChatTarget with basic settings. 
It initializes an openAI chat\n", + " target using the OPENAI_CLI_ENDPOINT and OPENAI_CLI_KEY environment\n", + " variables.\n", "\n", "================================================================================\n", "\n", - "Total initializers: 2\n", - "\n", - "For usage information, use: pyrit_scan --help\n" + "Total initializers: 2\n" ] } ], @@ -242,7 +254,13 @@ "Basic usage will look something like:\n", "\n", "```shell\n", - "pyrit_scan --initializers --scenario-strategies \n", + "pyrit_scan --initializers --scenario-strategies \n", + "```\n", + "\n", + "You can also override scenario parameters directly from the CLI:\n", + "\n", + "```shell\n", + "pyrit_scan --max-concurrency 10 --max-retries 3 --memory-labels '{\"experiment\": \"test1\", \"version\": \"v2\"}'\n", "```\n", "\n", "Or concretely:\n", @@ -264,6 +282,12 @@ "name": "stdout", "output_type": "stream", "text": [ + "Starting PyRIT...\n", + "Loading PyRIT modules...\n", + "Ready!\n", + "Running 1 initializer(s)...\n", + "\n", + "Running scenario: foundry_scenario\n", "\n", "\u001b[36m════════════════════════════════════════════════════════════════════════════════════════════════════\u001b[0m\n", "\u001b[1m\u001b[36m 📊 SCENARIO RESULTS: FoundryScenario \u001b[0m\n", @@ -327,14 +351,14 @@ "text": [ "\n", "Executing Foundry Scenario: 0%| | 0/2 [00:00`: Maximum number of concurrent attack executions\n", + "- `--max-retries `: Maximum number of automatic retries if the scenario raises an exception\n", + "- `--memory-labels `: Additional labels to apply to all attack runs (must be a JSON string with string keys and values)\n", "\n", - "You can define your own scenarios in initialization scripts. The CLI will automatically discover any `Scenario` subclasses and make them available:\n", + "You can also use custom initialization scripts by passing file paths. It is relative to your current working directory, but to avoid confusion, full paths are always better:\n", "\n", - "```python\n", + "```shell\n", + "pyrit_scan encoding_scenario --initialization-scripts ./my_custom_config.py\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "#### Using Custom Scenarios\n", + "\n", + "You can define your own scenarios in initialization scripts. The CLI will automatically discover any `Scenario` subclasses and make them available:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ "# my_custom_scenarios.py\n", - "from pyrit.scenarios import Scenario\n", "from pyrit.common.apply_defaults import apply_defaults\n", + "from pyrit.scenarios import Scenario\n", + "from pyrit.scenarios.scenario_strategy import ScenarioStrategy\n", "\n", - "@apply_defaults\n", - "class MyCustomScenario(Scenario):\n", - " \"\"\"My custom scenario that does XYZ.\"\"\"\n", "\n", - " def __init__(self, objective_target=None):\n", - " super().__init__(name=\"My Custom Scenario\", version=\"1.0\")\n", - " self.objective_target = objective_target\n", - " # ... your initialization code\n", + "class MyCustomStrategy(ScenarioStrategy):\n", + " \"\"\"Strategies for my custom scenario.\"\"\"\n", + " ALL = (\"all\", {\"all\"})\n", + " Strategy1 = (\"strategy1\", set[str]())\n", + " Strategy2 = (\"strategy2\", set[str]())\n", "\n", - " async def initialize_async(self):\n", - " # Load your atomic attacks\n", - " pass\n", "\n", - " # ... 
implement other required methods\n", - "```\n", + "@apply_defaults\n", + "class MyCustomScenario(Scenario):\n", + " \"\"\"My custom scenario that does XYZ.\"\"\"\n", "\n", + " @classmethod\n", + " def get_strategy_class(cls):\n", + " return MyCustomStrategy\n", + "\n", + " @classmethod\n", + " def get_default_strategy(cls):\n", + " return MyCustomStrategy.ALL\n", + "\n", + " def __init__(self, *, scenario_result_id=None, **kwargs):\n", + " # Scenario-specific configuration only - no runtime parameters\n", + " super().__init__(\n", + " name=\"My Custom Scenario\",\n", + " version=1,\n", + " strategy_class=MyCustomStrategy,\n", + " default_aggregate=MyCustomStrategy.ALL,\n", + " scenario_result_id=scenario_result_id,\n", + " )\n", + " # ... your scenario-specific initialization code\n", + "\n", + " async def _get_atomic_attacks_async(self):\n", + " # Build and return your atomic attacks\n", + " return []" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ "Then discover and run it:\n", "\n", "```shell\n", "# List to see it's available\n", "pyrit_scan --list-scenarios --initialization-scripts ./my_custom_scenarios.py\n", "\n", - "# Run it\n", - "pyrit_scan my_custom_scenario --initialization-scripts ./my_custom_scenarios.py\n", + "# Run it with parameter overrides\n", + "pyrit_scan my_custom_scenario --initialization-scripts ./my_custom_scenarios.py --max-concurrency 10\n", "```\n", "\n", - "The scenario name is automatically converted from the class name (e.g., `MyCustomScenario` becomes `my_custom_scenario`).\n", - "\n", - "\n", - "## When to Use the Scanner\n", - "\n", - "The scanner is ideal for:\n", - "\n", - "- **Automated testing pipelines**: CI/CD integration for continuous security testing\n", - "- **Batch testing**: Running multiple attack scenarios against various targets\n", - "- **Repeatable tests**: Standardized testing with consistent configurations\n", - "- **Team collaboration**: Shareable configuration files for consistent testing approaches\n", - "- **Quick testing**: Fast execution without writing Python code\n", - "\n", - "\n", - "## Complete Documentation\n", - "\n", - "For comprehensive documentation about initialization files and setting defaults see:\n", - "\n", - "- **Configuration**: See [configuration](../setup/1_configuration.ipynb)\n", - "- **Setting Default Values**: See [default values](../setup/default_values.md)\n", - "- **Writing Initializers**: See [Initializers](../setup/pyrit_initializer.ipynb)\n", - "\n", - "Or visit the [PyRIT documentation website](https://azure.github.io/PyRIT/)" + "The scenario name is automatically converted from the class name (e.g., `MyCustomScenario` becomes `my_custom_scenario`)." ] } ], diff --git a/doc/code/front_end/0_cli.py b/doc/code/front_end/1_pyrit_scan.py similarity index 57% rename from doc/code/front_end/0_cli.py rename to doc/code/front_end/1_pyrit_scan.py index 2fbf1bd25..2291e4210 100644 --- a/doc/code/front_end/0_cli.py +++ b/doc/code/front_end/1_pyrit_scan.py @@ -6,12 +6,16 @@ # format_name: percent # format_version: '1.3' # jupytext_version: 1.17.3 +# kernelspec: +# display_name: pyrit-dev +# language: python +# name: python3 # --- # %% [markdown] -# # The PyRIT CLI +# # 1. PyRIT Scan # -# The PyRIT cli tool that allows you to run automated security testing and red teaming attacks against AI systems using [scenarios](../scenarios/scenarios.ipynb) for strategies and [configuration](../setup/1_configuration.ipynb). 
+# `pyrit_scan` allows you to run automated security testing and red teaming attacks against AI systems using [scenarios](../scenarios/scenarios.ipynb) for strategies and [configuration](../setup/1_configuration.ipynb). # # Note in this doc the ! prefaces all commands in the terminal so we can run in a Jupyter Notebook. # @@ -62,7 +66,13 @@ # Basic usage will look something like: # # ```shell -# pyrit_scan --initializers --scenario-strategies +# pyrit_scan --initializers --scenario-strategies +# ``` +# +# You can also override scenario parameters directly from the CLI: +# +# ```shell +# pyrit_scan --max-concurrency 10 --max-retries 3 --memory-labels '{"experiment": "test1", "version": "v2"}' # ``` # # Or concretely: @@ -74,7 +84,7 @@ # Example with a basic configuration that runs the Foundry scenario against the objective target defined in `openai_objective_target` (which just is an OpenAIChatTarget with `DEFAULT_OPENAI_FRONTEND_ENDPOINT` and `DEFAULT_OPENAI_FRONTEND_KEY`). # %% -# !pyrit_scan foundry_scenario --initializers openai_objective_target --scenario-strategies base64 +# !pyrit_scan foundry_scenario --initializers openai_objective_target --strategies base64 # %% [markdown] # Or with all options and multiple initializers and multiple strategies: @@ -83,67 +93,86 @@ # pyrit_scan foundry_scenario --database InMemory --initializers simple objective_target objective_list --scenario-strategies easy crescendo # ``` # +# You can also override scenario execution parameters: +# +# ```shell +# # Override concurrency and retry settings +# pyrit_scan foundry_scenario --initializers simple objective_target --max-concurrency 10 --max-retries 3 +# +# # Add custom memory labels for tracking (must be valid JSON) +# pyrit_scan foundry_scenario --initializers simple objective_target --memory-labels '{"experiment": "test1", "version": "v2", "researcher": "alice"}' +# ``` +# +# Available CLI parameter overrides: +# - `--max-concurrency `: Maximum number of concurrent attack executions +# - `--max-retries `: Maximum number of automatic retries if the scenario raises an exception +# - `--memory-labels `: Additional labels to apply to all attack runs (must be a JSON string with string keys and values) +# # You can also use custom initialization scripts by passing file paths. It is relative to your current working directory, but to avoid confusion, full paths are always better: # # ```shell # pyrit_scan encoding_scenario --initialization-scripts ./my_custom_config.py # ``` -# + +# %% [markdown] # #### Using Custom Scenarios # # You can define your own scenarios in initialization scripts. The CLI will automatically discover any `Scenario` subclasses and make them available: # -# ```python -# # my_custom_scenarios.py -# from pyrit.scenarios import Scenario -# from pyrit.common.apply_defaults import apply_defaults -# -# @apply_defaults -# class MyCustomScenario(Scenario): -# """My custom scenario that does XYZ.""" -# -# def __init__(self, objective_target=None): -# super().__init__(name="My Custom Scenario", version="1.0") -# self.objective_target = objective_target -# # ... your initialization code -# -# async def initialize_async(self): -# # Load your atomic attacks -# pass -# -# # ... 
implement other required methods -# ``` -# + +from pyrit.common.apply_defaults import apply_defaults + +# %% +# my_custom_scenarios.py +from pyrit.scenarios import Scenario +from pyrit.scenarios.scenario_strategy import ScenarioStrategy + + +class MyCustomStrategy(ScenarioStrategy): + """Strategies for my custom scenario.""" + + ALL = ("all", {"all"}) + Strategy1 = ("strategy1", set[str]()) + Strategy2 = ("strategy2", set[str]()) + + +@apply_defaults +class MyCustomScenario(Scenario): + """My custom scenario that does XYZ.""" + + @classmethod + def get_strategy_class(cls): + return MyCustomStrategy + + @classmethod + def get_default_strategy(cls): + return MyCustomStrategy.ALL + + def __init__(self, *, scenario_result_id=None, **kwargs): + # Scenario-specific configuration only - no runtime parameters + super().__init__( + name="My Custom Scenario", + version=1, + strategy_class=MyCustomStrategy, + default_aggregate=MyCustomStrategy.ALL, + scenario_result_id=scenario_result_id, + ) + # ... your scenario-specific initialization code + + async def _get_atomic_attacks_async(self): + # Build and return your atomic attacks + return [] + + +# %% [markdown] # Then discover and run it: # # ```shell # # List to see it's available # pyrit_scan --list-scenarios --initialization-scripts ./my_custom_scenarios.py # -# # Run it -# pyrit_scan my_custom_scenario --initialization-scripts ./my_custom_scenarios.py +# # Run it with parameter overrides +# pyrit_scan my_custom_scenario --initialization-scripts ./my_custom_scenarios.py --max-concurrency 10 # ``` # # The scenario name is automatically converted from the class name (e.g., `MyCustomScenario` becomes `my_custom_scenario`). -# -# -# ## When to Use the Scanner -# -# The scanner is ideal for: -# -# - **Automated testing pipelines**: CI/CD integration for continuous security testing -# - **Batch testing**: Running multiple attack scenarios against various targets -# - **Repeatable tests**: Standardized testing with consistent configurations -# - **Team collaboration**: Shareable configuration files for consistent testing approaches -# - **Quick testing**: Fast execution without writing Python code -# -# -# ## Complete Documentation -# -# For comprehensive documentation about initialization files and setting defaults see: -# -# - **Configuration**: See [configuration](../setup/1_configuration.ipynb) -# - **Setting Default Values**: See [default values](../setup/default_values.md) -# - **Writing Initializers**: See [Initializers](../setup/pyrit_initializer.ipynb) -# -# Or visit the [PyRIT documentation website](https://azure.github.io/PyRIT/) diff --git a/doc/code/front_end/2_pyrit_shell.md b/doc/code/front_end/2_pyrit_shell.md new file mode 100644 index 000000000..c40a0229a --- /dev/null +++ b/doc/code/front_end/2_pyrit_shell.md @@ -0,0 +1,184 @@ +# PyRIT Shell - Interactive Command Line + +PyRIT Shell provides an interactive REPL (Read-Eval-Print Loop) for running AI red teaming scenarios with fast execution and session-based result tracking. 
+ +## Quick Start + +Start the shell: + +```bash +pyrit_shell +``` + +With startup options: + +```bash +# Set default database for all runs +pyrit_shell --database InMemory + +# Set default log level +pyrit_shell --log-level DEBUG + +# Load initializers at startup +pyrit_shell --initializers openai_objective_target + +# Load custom initialization scripts +pyrit_shell --initialization-scripts ./my_config.py +``` + +## Available Commands + +Once in the shell, you have access to: + +| Command | Description | +|---------|-------------| +| `list-scenarios` | List all available scenarios | +| `list-initializers` | List all available initializers | +| `run [options]` | Run a scenario with optional parameters | +| `scenario-history` | List all previous scenario runs in this session | +| `print-scenario [N]` | Print detailed results for scenario run(s) | +| `help [command]` | Show help for a command | +| `clear` | Clear the screen | +| `exit` (or `quit`, `q`) | Exit the shell | + +## Running Scenarios + +The `run` command executes scenarios with the same options as `pyrit_scan`: + +### Basic Usage + +```bash +pyrit> run foundry_scenario --initializers openai_objective_target +``` + +### With Strategies + +```bash +pyrit> run encoding_scenario --initializers openai_objective_target --strategies base64 rot13 + +pyrit> run foundry_scenario --initializers openai_objective_target -s jailbreak crescendo +``` + +### With Runtime Parameters + +```bash +# Set concurrency and retries +pyrit> run foundry_scenario --initializers openai_objective_target --max-concurrency 10 --max-retries 3 + +# Add memory labels for tracking +pyrit> run encoding_scenario --initializers openai_objective_target --memory-labels '{\"experiment\":\"test1\",\"version\":\"v2\"}' +``` + +### Override Defaults Per-Run + +```bash +# Override database and log level for this run only +pyrit> run encoding_scenario --initializers openai_objective_target --database InMemory --log-level DEBUG +``` + +### Run Command Options + +``` +--initializers ... Built-in initializers to run before the scenario (REQUIRED) +--initialization-scripts <...> Custom Python scripts to run before the scenario (alternative) +--strategies, -s ... Strategy names to use +--max-concurrency Maximum concurrent operations +--max-retries Maximum retry attempts +--memory-labels JSON string of labels +--database Override default database (InMemory, SQLite, AzureSQL) +--log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) +``` + +## Session History + +Track and review all scenario runs in your session: + +```bash +# Show all runs from this session +pyrit> scenario-history + +# Print details of the most recent run +pyrit> print-scenario + +# Print details of a specific run (by number from history) +pyrit> print-scenario 1 + +# Print all runs +pyrit> print-scenario +``` + +Example output: + +``` +pyrit> scenario-history + +Scenario Run History: +================================================================================ +1) foundry_scenario --initializers openai_objective_target --strategies base64 +2) encoding_scenario --initializers openai_objective_target --strategies rot13 +3) foundry_scenario --initializers openai_objective_target -s jailbreak +================================================================================ + +Total runs: 3 + +Use 'print-scenario ' to view detailed results for a specific run. 
+``` + +## Interactive Exploration + +The shell excels at interactive testing workflows: + +```bash +# Start shell with defaults +pyrit_shell --database InMemory --initializers openai_objective_target + +# Quick exploration +pyrit> list-scenarios +pyrit> run encoding_scenario --strategies base64 +pyrit> run encoding_scenario --strategies rot13 +pyrit> run encoding_scenario --strategies morse_code + +# Review and compare +pyrit> scenario-history +pyrit> print-scenario 1 +pyrit> print-scenario 2 +``` + +## Shell Benefits + +- **Fast Execution**: PyRIT modules load once at startup (typically 5-10 seconds), making subsequent commands instant +- **Session Tracking**: All runs are stored in history for easy comparison +- **Interactive Workflow**: Perfect for iterative testing and debugging +- **Persistent Context**: Default settings apply across multiple runs +- **Tab Completion**: Command and argument completion (if supported by your terminal) + +## Tips + +1. **Set defaults at startup** to avoid repeating options: + ```bash + pyrit_shell --database InMemory --log-level INFO + ``` + +2. **Use short strategy aliases** with `-s`: + ```bash + pyrit> run foundry_scenario --initializers openai_objective_target -s base64 rot13 + ``` + +3. **Review history regularly** to track what you've tested: + ```bash + pyrit> scenario-history + ``` + +4. **Print specific results** to compare outcomes: + ```bash + pyrit> print-scenario 1 # baseline run + pyrit> print-scenario 3 # modified run + ``` + +## Exit the Shell + +```bash +pyrit> exit +``` + +Or use `quit` or `q`. diff --git a/pyproject.toml b/pyproject.toml index c53d22be5..158c8b111 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,7 +144,8 @@ all = [ ] [project.scripts] -pyrit_scan = "pyrit.cli.__main__:main" +pyrit_scan = "pyrit.cli.pyrit_scan:main" +pyrit_shell = "pyrit.cli.pyrit_shell:main" [tool.pytest.ini_options] pythonpath = ["."] diff --git a/pyrit/cli/__main__.py b/pyrit/cli/__main__.py deleted file mode 100644 index fcdfe2d27..000000000 --- a/pyrit/cli/__main__.py +++ /dev/null @@ -1,511 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -""" -PyRIT CLI - Command-line interface for running security scenarios. - -This module provides the main entry point for the pyrit_scan command. -It supports running scenarios with configurable database backends and -initialization scripts. -""" - -# Standard library imports -import asyncio -import logging -import sys -from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter -from pathlib import Path -from typing import Any - -# Local application imports -from pyrit.cli.initializer_registry import InitializerInfo, InitializerRegistry -from pyrit.cli.scenario_registry import ScenarioRegistry -from pyrit.common.path import PYRIT_PATH -from pyrit.scenarios import Scenario -from pyrit.scenarios.printer.console_printer import ConsoleScenarioResultPrinter -from pyrit.setup import AZURE_SQL, IN_MEMORY, SQLITE, initialize_pyrit - -logger = logging.getLogger(__name__) - - -def parse_args(args=None) -> Namespace: - """ - Parse command-line arguments for the PyRIT scanner. - - Args: - args: Command-line arguments to parse. If None, uses sys.argv. - - Returns: - Namespace: Parsed arguments. 
- """ - parser = ArgumentParser( - prog="pyrit_scan", - description="""PyRIT Scanner - Run security scenarios against AI systems - -Examples: - # List available scenarios and initializers - pyrit_scan --list-scenarios - pyrit_scan --list-initializers - - # Run a scenario with built-in initializers - pyrit_scan foundry_scenario --initializers simple objective_target - - # Run with custom initialization scripts - pyrit_scan encoding_scenario --initialization-scripts ./my_config.py - - # Run specific strategies - pyrit_scan encoding_scenario --initializers simple objective_target --scenario-strategies base64 rot13 morse_code - pyrit_scan foundry_scenario --initializers simple objective_target --scenario-strategies base64 atbash -""", - formatter_class=RawDescriptionHelpFormatter, - ) - - # Global options - parser.add_argument( - "--verbose", - action="store_true", - help="Enable verbose logging output", - ) - - # Discovery/help options - parser.add_argument( - "--list-scenarios", - action="store_true", - help="List all available scenarios and exit", - ) - - parser.add_argument( - "--list-initializers", - action="store_true", - help="List all available scenario initializers and exit", - ) - - # Scenario name as positional argument - parser.add_argument( - "scenario_name", - type=str, - nargs="?", - help="Name of the scenario to run (e.g., encoding_scenario, foundry_scenario)", - ) - - # Scenario options - parser.add_argument( - "--database", - type=str, - choices=[IN_MEMORY, SQLITE, AZURE_SQL], - default=SQLITE, - help="Database type to use for memory storage", - ) - parser.add_argument( - "--initializers", - type=str, - nargs="+", - help="Built-in initializer names (e.g., simple, objective_target, objective_list)", - ) - parser.add_argument( - "--initialization-scripts", - type=str, - nargs="+", - help="Paths to custom Python initialization scripts that configure scenarios and defaults", - ) - parser.add_argument( - "--scenario-strategies", - type=str, - nargs="+", - help="List of strategy names to run (e.g., base64 rot13 morse_code). If not specified, uses scenario defaults.", - ) - - return parser.parse_args(args) - - -def list_scenarios(*, registry: ScenarioRegistry) -> None: - """ - Print all available scenarios to the console. - - Args: - registry (ScenarioRegistry): The scenario registry to query. - """ - scenarios = registry.list_scenarios() - - if not scenarios: - print("No scenarios found.") - return - - print("\nAvailable Scenarios:") - print("=" * 80) - - for scenario_info in scenarios: - _format_scenario_info(scenario_info=scenario_info) - - print("\n" + "=" * 80) - print(f"\nTotal scenarios: {len(scenarios)}") - print("\nFor usage information, use: pyrit_scan --help") - - -def list_initializers() -> None: - """ - Print all available scenario initializers to the console. - - Discovers initializers from pyrit/setup/initializers/scenarios directory only. - Note that other initializers (like 'simple' from the base directory) can still - be used with --initializers, they just won't appear in this listing. 
- """ - scenarios_path = Path(PYRIT_PATH) / "setup" / "initializers" / "scenarios" - registry = InitializerRegistry(discovery_path=scenarios_path) - initializer_info = registry.list_initializers() - - if not initializer_info: - print("No initializers found.") - return - - print("\nAvailable Scenario Initializers:") - print("=" * 80) - - for info in initializer_info: - _format_initializer_info(info=info) - - print("\n" + "=" * 80) - print(f"\nTotal initializers: {len(initializer_info)}") - print("\nFor usage information, use: pyrit_scan --help") - - -async def run_scenario_async( - *, - scenario_name: str, - registry: ScenarioRegistry, - scenario_strategies: list[str] | None = None, -) -> None: - """ - Run a specific scenario by name. - - Args: - scenario_name (str): Name of the scenario to run. - registry (ScenarioRegistry): The scenario registry to query. - scenario_strategies (list[str] | None): Optional list of strategy names to run. - If provided, these will be converted to ScenarioCompositeStrategy instances - and passed to the scenario's __init__. - - Raises: - ValueError: If the scenario is not found or cannot be instantiated. - """ - # Get the scenario class - scenario_class = registry.get_scenario(scenario_name) - - if scenario_class is None: - available_scenarios = ", ".join(registry.get_scenario_names()) - raise ValueError( - f"Scenario '{scenario_name}' not found.\n" - f"Available scenarios: {available_scenarios}\n" - f"Use 'pyrit_scan --list-scenarios' to see all available scenarios." - ) - - logger.info(f"Instantiating scenario: {scenario_class.__name__}") - - # Instantiate the scenario with optional strategy override - # The scenario should get its configuration from: - # 1. --scenario-strategies CLI flag (if provided) - # 2. Default values set by initialization scripts via @apply_defaults - # 3. Global variables set by initialization scripts - - scenario: Scenario - - try: - # Instantiate the scenario (no strategies in __init__) - scenario = scenario_class() # type: ignore[call-arg] - except TypeError as e: - # Check if this is a missing parameter error - error_msg = str(e) - if "missing" in error_msg.lower() and "required" in error_msg.lower(): - raise ValueError( - f"Failed to instantiate scenario '{scenario_name}'.\n" - f"The scenario requires parameters that are not configured.\n\n" - f"To fix this, provide an initialization script with --initialization-scripts that:\n" - f"1. Sets default values for the required parameters using set_default_value()\n" - f"2. 
Or sets global variables that can be used by the scenario\n\n" - f"Original error: {error_msg}" - ) from e - else: - raise ValueError(f"Failed to instantiate scenario '{scenario_name}'.\n" f"Error: {error_msg}") from e - except Exception as e: - raise ValueError( - f"Failed to instantiate scenario '{scenario_name}'.\n" - f"Make sure your initialization scripts properly configure all required parameters.\n" - f"Error: {e}" - ) from e - - logger.info(f"Initializing scenario: {scenario.name}") - - # Convert strategy names to enum instances if provided - strategy_enums = None - if scenario_strategies: - strategy_class = scenario_class.get_strategy_class() - strategy_enums = [] - for strategy_name in scenario_strategies: - try: - strategy_enums.append(strategy_class(strategy_name)) - except ValueError: - available_strategies = [s.value for s in strategy_class] - raise ValueError( - f"Strategy '{strategy_name}' not found in {scenario_class.__name__}.\n" - f"Available strategies: {', '.join(available_strategies)}\n" - f"Use 'pyrit_scan --list-scenarios' to see available strategies for each scenario." - ) from None - - # Initialize the scenario with strategy override if provided - # Note: objective_target and other parameters should be set via initialization scripts - # using set_default_value() which will be applied by the @apply_defaults decorator - try: - if strategy_enums: - await scenario.initialize_async(scenario_strategies=strategy_enums) # type: ignore[call-arg] - else: - await scenario.initialize_async() # type: ignore[call-arg] - except Exception as e: - raise ValueError(f"Failed to initialize scenario '{scenario_name}'.\n" f"Error: {e}") from e - - logger.info(f"Running scenario: {scenario.name} with {scenario.atomic_attack_count} atomic attacks") - - # Run the scenario - try: - result = await scenario.run_async() - - # Print results using ConsoleScenarioResultPrinter - printer = ConsoleScenarioResultPrinter() - await printer.print_summary_async(result) - - except Exception as e: - raise ValueError(f"Failed to run scenario '{scenario_name}'.\n" f"Error: {e}") from e - - -def _format_wrapped_text(*, text: str, indent: str = " ", max_width: int = 80) -> None: - """ - Print text with word wrapping at the specified indentation level. - - Args: - text (str): The text to wrap and print. - indent (str): The indentation string to use for each line. - max_width (int): The maximum line width including indentation. - """ - words = text.split() - line = indent - - for word in words: - if len(line) + len(word) + 1 > max_width: - print(line) - line = indent + word - else: - if line == indent: - line += word - else: - line += " " + word - - if line.strip(): - print(line) - - -def _format_scenario_info(*, scenario_info: dict[str, Any]) -> None: - """ - Print formatted information about a scenario. - - Args: - scenario_info (dict[str, Any]): Dictionary containing scenario information. 
- Expected keys: name, class_name, description, default_strategy, aggregate_strategies, all_strategies - """ - print(f"\n {scenario_info['name']}") - print(f" Class: {scenario_info['class_name']}") - print(" Description:") - _format_wrapped_text(text=scenario_info["description"], indent=" ") - - # Display aggregate strategies if present - if scenario_info.get("aggregate_strategies"): - print(" Aggregate Strategies:") - for strategy in scenario_info["aggregate_strategies"]: - print(f" - {strategy}") - - # Display all strategies if present - if scenario_info.get("all_strategies"): - print(f" Available Strategies ({len(scenario_info['all_strategies'])}):") - strategies_text = ", ".join(scenario_info["all_strategies"]) - _format_wrapped_text(text=strategies_text, indent=" ") - - # Display default strategy if present - if scenario_info.get("default_strategy"): - print(f" Default Strategy: {scenario_info['default_strategy']}") - - -def _format_initializer_info(*, info: InitializerInfo) -> None: - """ - Print formatted information about an initializer. - - Args: - info (InitializerInfo): Dictionary containing initializer information. - """ - print(f"\n {info['name']}") - print(f" Class: {info['class_name']}") - print(f" Name: {info['initializer_name']}") - print(f" Execution Order: {info['execution_order']}") - - if info["required_env_vars"]: - print(" Required Environment Variables:") - for env_var in info["required_env_vars"]: - print(f" - {env_var}") - else: - print(" Required Environment Variables: None") - - print(" Description:") - _format_wrapped_text(text=info["description"], indent=" ") - - -def _collect_initialization_scripts(*, parsed_args: Namespace) -> list[Path] | None: - """ - Collect all initialization scripts from built-in initializers and custom scripts. - - Args: - parsed_args (Namespace): Parsed command-line arguments. - - Returns: - list[Path] | None: List of script paths, or None if no scripts. - - Raises: - ValueError: If initializer lookup fails. - FileNotFoundError: If custom script path does not exist. - """ - initialization_scripts = [] - - # Handle built-in initializers - if hasattr(parsed_args, "initializers") and parsed_args.initializers: - registry = InitializerRegistry() - paths = registry.resolve_initializer_paths(initializer_names=parsed_args.initializers) - initialization_scripts.extend(paths) - - # Handle custom initialization scripts - if hasattr(parsed_args, "initialization_scripts") and parsed_args.initialization_scripts: - paths = InitializerRegistry.resolve_script_paths(script_paths=parsed_args.initialization_scripts) - initialization_scripts.extend(paths) - - return initialization_scripts if initialization_scripts else None - - -def _handle_list_scenarios(*, parsed_args: Namespace) -> int: - """ - Handle the --list-scenarios flag. - - Args: - parsed_args (Namespace): Parsed command-line arguments. - - Returns: - int: Exit code (0 for success, 1 for error). 
- """ - initialization_scripts = None - - if hasattr(parsed_args, "initialization_scripts") and parsed_args.initialization_scripts: - try: - initialization_scripts = InitializerRegistry.resolve_script_paths( - script_paths=parsed_args.initialization_scripts - ) - - database = parsed_args.database if hasattr(parsed_args, "database") else SQLITE - initialize_pyrit( - memory_db_type=database, - initialization_scripts=initialization_scripts, - ) - except FileNotFoundError as e: - print(f"Error: {e}") - return 1 - - registry = ScenarioRegistry() - if initialization_scripts: - registry.discover_user_scenarios() - list_scenarios(registry=registry) - return 0 - - -def _run_scenario(*, parsed_args: Namespace) -> int: - """ - Run the specified scenario with the given configuration. - - Args: - parsed_args (Namespace): Parsed command-line arguments. - - Returns: - int: Exit code (0 for success, 1 for error). - """ - try: - # Collect all initialization scripts - initialization_scripts = _collect_initialization_scripts(parsed_args=parsed_args) - - # Initialize PyRIT - database = parsed_args.database if hasattr(parsed_args, "database") else SQLITE - logger.info(f"Initializing PyRIT with database type: {database}") - - initialize_pyrit( - memory_db_type=database, - initialization_scripts=initialization_scripts, - ) - - # Create scenario registry and discover user scenarios - registry = ScenarioRegistry() - registry.discover_user_scenarios() - - # Get scenario strategies from CLI args if provided - scenario_strategies = parsed_args.scenario_strategies if hasattr(parsed_args, "scenario_strategies") else None - - # Run the scenario - asyncio.run( - run_scenario_async( - scenario_name=parsed_args.scenario_name, - registry=registry, - scenario_strategies=scenario_strategies, - ) - ) - return 0 - - except (ValueError, FileNotFoundError) as e: - print(f"Error: {e}") - return 1 - except Exception as e: - logger.error(f"Error: {e}", exc_info=True) - print(f"\nError: {e}") - return 1 - - -def main(args=None) -> int: - """ - Main entry point for the PyRIT scanner CLI. - - Args: - args: Command-line arguments. If None, uses sys.argv. - - Returns: - int: Exit code (0 for success, 1 for error). - """ - # Parse arguments - try: - parsed_args = parse_args(args) - except SystemExit as e: - return e.code if isinstance(e.code, int) else 1 - - # Configure logging - if parsed_args.verbose: - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - else: - logging.basicConfig(level=logging.WARNING, format="%(levelname)s: %(message)s") - - # Handle discovery flags - if parsed_args.list_scenarios: - return _handle_list_scenarios(parsed_args=parsed_args) - - if parsed_args.list_initializers: - list_initializers() - return 0 - - # Verify scenario was provided - if not parsed_args.scenario_name: - print("Error: No scenario specified. Use --help for usage information.") - return 1 - - # Run the scenario - return _run_scenario(parsed_args=parsed_args) - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py new file mode 100644 index 000000000..e84f6f9e3 --- /dev/null +++ b/pyrit/cli/frontend_core.py @@ -0,0 +1,762 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared core logic for PyRIT Frontends. 
+ +This module contains all the business logic for: +- Loading and discovering scenarios +- Running scenarios +- Formatting output +- Managing initialization scripts + +Both pyrit_scan and pyrit_shell use these functions. +""" + +from __future__ import annotations + +import json +import logging +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, TypedDict + +try: + import termcolor # type: ignore + + HAS_TERMCOLOR = True +except ImportError: + HAS_TERMCOLOR = False + + # Create a dummy termcolor module for fallback + class termcolor: # type: ignore + @staticmethod + def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: ignore + print(text) + + +if TYPE_CHECKING: + from pyrit.cli.initializer_registry import InitializerInfo, InitializerRegistry + from pyrit.cli.scenario_registry import ScenarioRegistry + from pyrit.models.scenario_result import ScenarioResult + +logger = logging.getLogger(__name__) + +# Database type constants +IN_MEMORY = "InMemory" +SQLITE = "SQLite" +AZURE_SQL = "AzureSQL" + + +class ScenarioInfo(TypedDict): + """Type definition for scenario information dictionary.""" + + name: str + class_name: str + description: str + default_strategy: str + all_strategies: list[str] + aggregate_strategies: list[str] + + +class FrontendCore: + """ + Shared context for PyRIT operations. + + This object holds all the registries and configuration needed to run + scenarios. It can be created once (for shell) or per-command (for CLI). + """ + + def __init__( + self, + *, + database: str = SQLITE, + initialization_scripts: Optional[list[Path]] = None, + initializer_names: Optional[list[str]] = None, + log_level: str = "WARNING", + ): + """ + Initialize PyRIT context. + + Args: + database: Database type (InMemory, SQLite, or AzureSQL). + initialization_scripts: Optional list of initialization script paths. + initializer_names: Optional list of built-in initializer names to run. + log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to WARNING. + + Raises: + ValueError: If database or log_level are invalid. 
+ """ + # Validate inputs + self._database = validate_database(database=database) + self._initialization_scripts = initialization_scripts + self._initializer_names = initializer_names + self._log_level = validate_log_level(log_level=log_level) + + # Lazy-loaded registries + self._scenario_registry: Optional[ScenarioRegistry] = None + self._initializer_registry: Optional[InitializerRegistry] = None + self._initialized = False + + # Configure logging + logging.basicConfig(level=getattr(logging, self._log_level)) + + def initialize(self) -> None: + """Initialize PyRIT and load registries (heavy operation).""" + if self._initialized: + return + + print("Loading PyRIT modules...") + sys.stdout.flush() + + from pyrit.cli.initializer_registry import InitializerRegistry + from pyrit.cli.scenario_registry import ScenarioRegistry + from pyrit.setup import initialize_pyrit + + # Initialize PyRIT without initializers (they run per-scenario) + initialize_pyrit( + memory_db_type=self._database, + initialization_scripts=None, + initializers=None, + ) + + # Load registries + self._scenario_registry = ScenarioRegistry() + if self._initialization_scripts: + print("Discovering user scenarios...") + sys.stdout.flush() + self._scenario_registry.discover_user_scenarios() + + self._initializer_registry = InitializerRegistry() + + self._initialized = True + + @property + def scenario_registry(self) -> "ScenarioRegistry": + """Get the scenario registry (initializes if needed).""" + if not self._initialized: + self.initialize() + assert self._scenario_registry is not None + return self._scenario_registry + + @property + def initializer_registry(self) -> "InitializerRegistry": + """Get the initializer registry (initializes if needed).""" + if not self._initialized: + self.initialize() + assert self._initializer_registry is not None + return self._initializer_registry + + +def list_scenarios(*, context: FrontendCore) -> list[ScenarioInfo]: + """ + List all available scenarios. + + Args: + context: PyRIT context with loaded registries. + + Returns: + List of scenario info dictionaries. + """ + return context.scenario_registry.list_scenarios() + + +def list_initializers(*, context: FrontendCore, discovery_path: Optional[Path] = None) -> "Sequence[InitializerInfo]": + """ + List all available initializers. + + Args: + context: PyRIT context with loaded registries. + discovery_path: Optional path to discover initializers from. + + Returns: + Sequence of initializer info dictionaries. + """ + if discovery_path: + from pyrit.cli.initializer_registry import InitializerRegistry + + registry = InitializerRegistry(discovery_path=discovery_path) + return registry.list_initializers() + return context.initializer_registry.list_initializers() + + +async def run_scenario_async( + *, + scenario_name: str, + context: FrontendCore, + scenario_strategies: Optional[list[str]] = None, + max_concurrency: Optional[int] = None, + max_retries: Optional[int] = None, + memory_labels: Optional[dict[str, str]] = None, + print_summary: bool = True, +) -> "ScenarioResult": + """ + Run a scenario by name. + + Args: + scenario_name: Name of the scenario to run. + context: PyRIT context with loaded registries. + scenario_strategies: Optional list of strategy names. + max_concurrency: Max concurrent operations. + max_retries: Max retry attempts. + memory_labels: Labels to attach to memory entries. + print_summary: Whether to print the summary after execution. Defaults to True. + + Returns: + ScenarioResult: The result of the scenario execution. 
+ + Raises: + ValueError: If scenario not found or fails to run. + + Note: + Initializers from PyRITContext will be run before the scenario executes. + """ + from pyrit.scenarios.printer.console_printer import ConsoleScenarioResultPrinter + from pyrit.setup import initialize_pyrit + + # Ensure context is initialized first (loads registries) + # This must happen BEFORE we run initializers to avoid double-initialization + if not context._initialized: + context.initialize() + + # Run initializers before scenario + initializer_instances = None + if context._initializer_names: + print(f"Running {len(context._initializer_names)} initializer(s)...") + sys.stdout.flush() + + initializer_instances = [] + + for name in context._initializer_names: + initializer_class = context.initializer_registry.get_initializer_class(name=name) + initializer_instances.append(initializer_class()) + + # Re-initialize PyRIT with the scenario-specific initializers + # This resets memory and applies initializer defaults + initialize_pyrit( + memory_db_type=context._database, + initialization_scripts=context._initialization_scripts, + initializers=initializer_instances, + ) + + # Get scenario class + scenario_class = context.scenario_registry.get_scenario(scenario_name) + + if scenario_class is None: + available = ", ".join(context.scenario_registry.get_scenario_names()) + raise ValueError(f"Scenario '{scenario_name}' not found.\n" f"Available scenarios: {available}") + + # Build initialization kwargs (these go to initialize_async, not __init__) + init_kwargs: dict[str, Any] = {} + + if scenario_strategies: + strategy_class = scenario_class.get_strategy_class() + strategy_enums = [] + for name in scenario_strategies: + try: + strategy_enums.append(strategy_class(name)) + except ValueError: + available_strategies = [s.value for s in strategy_class] + raise ValueError( + f"Strategy '{name}' not found for scenario '{scenario_name}'. " + f"Available: {', '.join(available_strategies)}" + ) from None + init_kwargs["scenario_strategies"] = strategy_enums + + if max_concurrency is not None: + init_kwargs["max_concurrency"] = max_concurrency + if max_retries is not None: + init_kwargs["max_retries"] = max_retries + if memory_labels is not None: + init_kwargs["memory_labels"] = memory_labels + + # Instantiate and run + print(f"\nRunning scenario: {scenario_name}") + sys.stdout.flush() + + # Scenarios here are a concrete subclass + # Runtime parameters are passed to initialize_async() + scenario = scenario_class() # type: ignore[call-arg] + await scenario.initialize_async(**init_kwargs) + result = await scenario.run_async() + + # Print results if requested + if print_summary: + printer = ConsoleScenarioResultPrinter() + await printer.print_summary_async(result) + + return result + + +def _format_wrapped_text(*, text: str, indent: str, width: int = 78) -> str: + """ + Format text with word wrapping. + + Args: + text: Text to wrap. + indent: Indentation string for wrapped lines. + width: Maximum line width. Defaults to 78. + + Returns: + Formatted text with line breaks. 
+ """ + words = text.split() + lines = [] + current_line = "" + + for word in words: + if not current_line: + current_line = word + elif len(current_line) + len(word) + 1 + len(indent) <= width: + current_line += " " + word + else: + lines.append(indent + current_line) + current_line = word + + if current_line: + lines.append(indent + current_line) + + return "\n".join(lines) + + +def _print_header(*, text: str) -> None: + """ + Print a colored header if termcolor is available. + + Args: + text: Header text to print. + """ + if HAS_TERMCOLOR: + termcolor.cprint(f"\n {text}", "cyan", attrs=["bold"]) + else: + print(f"\n {text}") + + +def format_scenario_info(*, scenario_info: ScenarioInfo) -> None: + """ + Print formatted information about a scenario. + + Args: + scenario_info: Dictionary containing scenario information. + """ + _print_header(text=scenario_info["name"]) + print(f" Class: {scenario_info['class_name']}") + + description = scenario_info.get("description", "") + if description: + print(" Description:") + print(_format_wrapped_text(text=description, indent=" ")) + + if scenario_info.get("aggregate_strategies"): + agg_strategies = scenario_info["aggregate_strategies"] + print(" Aggregate Strategies:") + formatted = _format_wrapped_text(text=", ".join(agg_strategies), indent=" - ") + print(formatted) + + if scenario_info.get("all_strategies"): + strategies = scenario_info["all_strategies"] + print(f" Available Strategies ({len(strategies)}):") + formatted = _format_wrapped_text(text=", ".join(strategies), indent=" ") + print(formatted) + + if scenario_info.get("default_strategy"): + print(f" Default Strategy: {scenario_info['default_strategy']}") + + +def format_initializer_info(*, initializer_info: "InitializerInfo") -> None: + """ + Print formatted information about an initializer. + + Args: + initializer_info: Dictionary containing initializer information. + """ + _print_header(text=initializer_info["name"]) + print(f" Class: {initializer_info['class_name']}") + print(f" Name: {initializer_info['initializer_name']}") + print(f" Execution Order: {initializer_info['execution_order']}") + + if initializer_info.get("required_env_vars"): + print(" Required Environment Variables:") + for env_var in initializer_info["required_env_vars"]: + print(f" - {env_var}") + else: + print(" Required Environment Variables: None") + + if initializer_info.get("description"): + print(" Description:") + print(_format_wrapped_text(text=initializer_info["description"], indent=" ")) + + +def validate_database(*, database: str) -> str: + """ + Validate database type. + + Args: + database: Database type string. + + Returns: + Validated database type. + + Raises: + ValueError: If database type is invalid. + """ + valid_databases = [IN_MEMORY, SQLITE, AZURE_SQL] + if database not in valid_databases: + raise ValueError(f"Invalid database type: {database}. " f"Must be one of: {', '.join(valid_databases)}") + return database + + +def validate_log_level(*, log_level: str) -> str: + """ + Validate log level. + + Args: + log_level: Log level string (case-insensitive). + + Returns: + Validated log level in uppercase. + + Raises: + ValueError: If log level is invalid. + """ + valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + level_upper = log_level.upper() + if level_upper not in valid_levels: + raise ValueError(f"Invalid log level: {log_level}. 
" f"Must be one of: {', '.join(valid_levels)}") + return level_upper + + +def validate_integer(value: str, *, name: str = "value", min_value: Optional[int] = None) -> int: + """ + Validate and parse an integer value. + + Note: The 'value' parameter is positional (not keyword-only) to allow use with + argparse lambdas like: lambda v: validate_integer(v, min_value=1). + This is an exception to the PyRIT style guide for argparse compatibility. + + Args: + value: String value to parse. + name: Parameter name for error messages. Defaults to "value". + min_value: Optional minimum value constraint. + + Returns: + Parsed integer. + + Raises: + ValueError: If value is not a valid integer or violates constraints. + """ + # Reject boolean types explicitly (int(True) == 1, int(False) == 0) + if isinstance(value, bool): + raise ValueError(f"{name} must be an integer string, got boolean: {value}") + + # Ensure value is a string + if not isinstance(value, str): + raise ValueError(f"{name} must be a string, got {type(value).__name__}: {value}") + + # Strip whitespace and validate it looks like an integer + value = value.strip() + if not value: + raise ValueError(f"{name} cannot be empty") + + try: + int_value = int(value) + except (ValueError, TypeError) as e: + raise ValueError(f"{name} must be an integer, got: {value}") from e + + if min_value is not None and int_value < min_value: + raise ValueError(f"{name} must be at least {min_value}, got: {int_value}") + + return int_value + + +def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], Any]: + """ + Decorator to convert ValueError to argparse.ArgumentTypeError. + + This decorator adapts our keyword-only validators for use with argparse's type= parameter. + It handles two challenges: + + 1. Exception Translation: argparse expects ArgumentTypeError, but our validators raise + ValueError. This decorator catches ValueError and re-raises as ArgumentTypeError. + + 2. Keyword-Only Parameters: PyRIT validators use keyword-only parameters (e.g., + validate_database(*, database: str)), but argparse's type= passes a positional argument. + This decorator inspects the function signature and calls the validator with the correct + keyword argument name. + + This pattern allows us to: + - Keep validators as pure functions with proper type hints + - Follow PyRIT style guide (keyword-only parameters) + - Reuse the same validation logic in both argparse and non-argparse contexts + + Args: + validator_func: Function that raises ValueError on invalid input. + Must have at least one parameter (can be keyword-only). 
+ + Returns: + Wrapped function that: + - Accepts a single positional argument (for argparse compatibility) + - Calls validator_func with the correct keyword argument + - Raises ArgumentTypeError instead of ValueError + """ + import inspect + + # Get the first parameter name from the function signature + sig = inspect.signature(validator_func) + params = list(sig.parameters.keys()) + if not params: + raise ValueError(f"Validator function {validator_func.__name__} must have at least one parameter") + first_param = params[0] + + def wrapper(value): + import argparse as ap + + try: + # Call with keyword argument to support keyword-only parameters + return validator_func(**{first_param: value}) + except ValueError as e: + raise ap.ArgumentTypeError(str(e)) from e + + # Preserve function metadata for better debugging + wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") + wrapper.__doc__ = getattr(validator_func, "__doc__", None) + return wrapper + + +# Argparse-compatible validators +# +# These wrappers adapt our core validators (which use keyword-only parameters and raise +# ValueError) for use with argparse's type= parameter (which passes positional arguments +# and expects ArgumentTypeError). +# +# Pattern: +# - Use core validators (validate_database, validate_log_level, etc.) in regular code +# - Use these _argparse versions ONLY in parser.add_argument(..., type=...) +# +# The lambda wrappers for validate_integer are necessary because we need to partially +# apply the min_value parameter while still allowing the decorator to work correctly. +validate_database_argparse = _argparse_validator(validate_database) +validate_log_level_argparse = _argparse_validator(validate_log_level) +positive_int = _argparse_validator(lambda v: validate_integer(v, min_value=1)) +non_negative_int = _argparse_validator(lambda v: validate_integer(v, min_value=0)) + + +def parse_memory_labels(json_string: str) -> dict[str, str]: + """ + Parse memory labels from a JSON string. + + Args: + json_string: JSON string containing label key-value pairs. + + Returns: + Dictionary of labels. + + Raises: + ValueError: If JSON is invalid or contains non-string values. + """ + try: + labels = json.loads(json_string) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON for memory labels: {e}") from e + + if not isinstance(labels, dict): + raise ValueError("Memory labels must be a JSON object (dictionary)") + + # Validate all keys and values are strings + for key, value in labels.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise ValueError(f"All label keys and values must be strings. Got: {key}={value}") + + return labels + + +def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: + """ + Resolve initialization script paths. + + Args: + script_paths: List of script path strings. + + Returns: + List of resolved Path objects. + + Raises: + FileNotFoundError: If a script path does not exist. + """ + from pyrit.cli.initializer_registry import InitializerRegistry + + return InitializerRegistry.resolve_script_paths(script_paths=script_paths) + + +def get_default_initializer_discovery_path() -> Path: + """ + Get the default path for discovering initializers. + + Returns: + Path to the scenarios initializers directory. 
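+
+    Example (illustrative; the concrete prefix depends on where the pyrit
+    package lives on disk):
+        get_default_initializer_discovery_path()
+        # -> <pyrit package root>/setup/initializers/scenarios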
+ """ + PYRIT_PATH = Path(__file__).parent.parent.resolve() + return PYRIT_PATH / "setup" / "initializers" / "scenarios" + + +def print_scenarios_list(*, context: FrontendCore) -> int: + """ + Print a formatted list of all available scenarios. + + Args: + context: PyRIT context with loaded registries. + + Returns: + Exit code (0 for success). + """ + scenarios = list_scenarios(context=context) + + if not scenarios: + print("No scenarios found.") + return 0 + + print("\nAvailable Scenarios:") + print("=" * 80) + for scenario_info in scenarios: + format_scenario_info(scenario_info=scenario_info) + print("\n" + "=" * 80) + print(f"\nTotal scenarios: {len(scenarios)}") + return 0 + + +def print_initializers_list(*, context: FrontendCore, discovery_path: Optional[Path] = None) -> int: + """ + Print a formatted list of all available initializers. + + Args: + context: PyRIT context with loaded registries. + discovery_path: Optional path to discover initializers from. + + Returns: + Exit code (0 for success). + """ + initializers = list_initializers(context=context, discovery_path=discovery_path) + + if not initializers: + print("No initializers found.") + return 0 + + print("\nAvailable Initializers:") + print("=" * 80) + for initializer_info in initializers: + format_initializer_info(initializer_info=initializer_info) + print("\n" + "=" * 80) + print(f"\nTotal initializers: {len(initializers)}") + return 0 + + +# Shared argument help text +ARG_HELP = { + "initializers": "Built-in initializer names to run before the scenario (e.g., openai_objective_target)", + "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", + "scenario_strategies": "List of strategy names to run (e.g., base64 rot13)", + "max_concurrency": "Maximum number of concurrent attack executions (must be >= 1)", + "max_retries": "Maximum number of automatic retries on exception (must be >= 0)", + "memory_labels": 'Additional labels as JSON string (e.g., \'{"experiment": "test1"}\')', + "database": "Database type to use for memory storage", + "log_level": "Logging level", +} + + +def parse_run_arguments(*, args_string: str) -> dict[str, Any]: + """ + Parse run command arguments from a string (for shell mode). + + Args: + args_string: Space-separated argument string (e.g., "scenario_name --initializers foo --strategies bar"). + + Returns: + Dictionary with parsed arguments: + - scenario_name: str + - initializers: Optional[list[str]] + - initialization_scripts: Optional[list[str]] + - scenario_strategies: Optional[list[str]] + - max_concurrency: Optional[int] + - max_retries: Optional[int] + - memory_labels: Optional[dict[str, str]] + - database: Optional[str] + - log_level: Optional[str] + + Raises: + ValueError: If parsing or validation fails. 
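+
+    Example (illustrative; keys not supplied on the command line keep their
+    default of None):
+        parse_run_arguments(
+            args_string="encoding_scenario --initializers openai_objective_target --strategies base64 rot13"
+        )
+        # -> {"scenario_name": "encoding_scenario",
+        #     "initializers": ["openai_objective_target"],
+        #     "scenario_strategies": ["base64", "rot13"],
+        #     ... remaining keys None}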
+ """ + parts = args_string.split() + + if not parts: + raise ValueError("No scenario name provided") + + result: dict[str, Any] = { + "scenario_name": parts[0], + "initializers": None, + "initialization_scripts": None, + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": None, + "log_level": None, + } + + i = 1 + while i < len(parts): + if parts[i] == "--initializers": + # Collect initializers until next flag + result["initializers"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initializers"].append(parts[i]) + i += 1 + elif parts[i] == "--initialization-scripts": + # Collect script paths until next flag + result["initialization_scripts"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--"): + result["initialization_scripts"].append(parts[i]) + i += 1 + elif parts[i] in ("--strategies", "-s"): + # Collect strategies until next flag + result["scenario_strategies"] = [] + i += 1 + while i < len(parts) and not parts[i].startswith("--") and parts[i] != "-s": + result["scenario_strategies"].append(parts[i]) + i += 1 + elif parts[i] == "--max-concurrency": + i += 1 + if i >= len(parts): + raise ValueError("--max-concurrency requires a value") + result["max_concurrency"] = validate_integer(parts[i], name="--max-concurrency", min_value=1) + i += 1 + elif parts[i] == "--max-retries": + i += 1 + if i >= len(parts): + raise ValueError("--max-retries requires a value") + result["max_retries"] = validate_integer(parts[i], name="--max-retries", min_value=0) + i += 1 + elif parts[i] == "--memory-labels": + i += 1 + if i >= len(parts): + raise ValueError("--memory-labels requires a value") + result["memory_labels"] = parse_memory_labels(parts[i]) + i += 1 + elif parts[i] == "--database": + i += 1 + if i >= len(parts): + raise ValueError("--database requires a value") + result["database"] = validate_database(database=parts[i]) + i += 1 + elif parts[i] == "--log-level": + i += 1 + if i >= len(parts): + raise ValueError("--log-level requires a value") + result["log_level"] = validate_log_level(log_level=parts[i]) + i += 1 + else: + logger.warning(f"Unknown argument: {parts[i]}") + i += 1 + + return result diff --git a/pyrit/cli/initializer_registry.py b/pyrit/cli/initializer_registry.py index a5c605304..a53e2a76c 100644 --- a/pyrit/cli/initializer_registry.py +++ b/pyrit/cli/initializer_registry.py @@ -1,20 +1,35 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + """ Initializer registry for discovering and cataloging PyRIT initializers. -This module provides functionality to discover all available PyRITInitializer subclasses +This module provides functionality to discover all available PyRITInitializer subclasses. + +PERFORMANCE OPTIMIZATION: +This module uses lazy imports and direct path computation to minimize import overhead: + +1. Lazy Imports via TYPE_CHECKING: PyRITInitializer is only imported for type checking, + not at runtime. Runtime imports happen inside methods when actually needed. + +2. Direct Path Computation: Computes PYRIT_PATH directly using __file__ instead of importing + from pyrit.common.path, avoiding loading of the pyrit package. 
""" import importlib.util import inspect import logging from pathlib import Path -from typing import Dict, List, TypedDict +from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict + +# Compute PYRIT_PATH directly to avoid importing pyrit package +# (which triggers heavy imports from __init__.py) +PYRIT_PATH = Path(__file__).parent.parent.resolve() -from pyrit.common.path import PYRIT_PATH -from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer +if TYPE_CHECKING: + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer logger = logging.getLogger(__name__) @@ -52,11 +67,13 @@ def __init__(self, *, discovery_path: Path | None = None) -> None: """ self._initializers: Dict[str, InitializerInfo] = {} self._initializer_paths: Dict[str, Path] = {} # Track file paths for collision detection + self._initializer_metadata: Optional[List[InitializerInfo]] = None if discovery_path is None: discovery_path = Path(PYRIT_PATH) / "setup" / "initializers" self._discovery_path = discovery_path + self._discover_initializers() def _discover_initializers(self) -> None: @@ -98,6 +115,9 @@ def _process_file(self, *, file_path: Path) -> None: Args: file_path (Path): Path to the Python file to process. """ + # Runtime import to avoid loading heavy modules at module level + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + # Use filename as the short name (e.g., "objective_target" instead of "scenarios.objective_target") short_name = file_path.stem @@ -183,8 +203,17 @@ def list_initializers(self) -> List[InitializerInfo]: List[InitializerInfo]: List of initializer information dictionaries, sorted by execution order and then by name. """ + # Return cached metadata if available + if self._initializer_metadata is not None: + return self._initializer_metadata + + # Build from discovered initializers initializers_list = list(self._initializers.values()) initializers_list.sort(key=lambda x: (x["execution_order"], x["name"])) + + # Cache for subsequent calls + self._initializer_metadata = initializers_list + return initializers_list def get_initializer_names(self) -> List[str]: @@ -230,6 +259,46 @@ def resolve_initializer_paths(self, *, initializer_names: list[str]) -> list[Pat return resolved_paths + def get_initializer_class(self, *, name: str) -> type["PyRITInitializer"]: + """ + Get the initializer class by name. + + Args: + name: The initializer name. + + Returns: + The initializer class. + + Raises: + ValueError: If initializer not found. + """ + import importlib.util + + initializer_info = self.get_initializer(name) + if initializer_info is None: + available = ", ".join(sorted(self.get_initializer_names())) + raise ValueError( + f"Initializer '{name}' not found.\n" + f"Available initializers: {available}\n" + f"Use 'pyrit_scan --list-initializers' to see detailed information." 
+ ) + + initializer_file = self._initializer_paths.get(name) + if initializer_file is None: + raise ValueError(f"Could not locate file for initializer '{name}'.") + + # Load the module + spec = importlib.util.spec_from_file_location("initializer_module", initializer_file) + if spec is None or spec.loader is None: + raise ValueError(f"Failed to load initializer from {initializer_file}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Get the initializer class + initializer_class = getattr(module, initializer_info["class_name"]) + return initializer_class + @staticmethod def resolve_script_paths(*, script_paths: list[str]) -> list[Path]: """ diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py new file mode 100644 index 000000000..b85c3bb86 --- /dev/null +++ b/pyrit/cli/pyrit_scan.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +PyRIT CLI - Command-line interface for running security scenarios. + +This module provides the main entry point for the pyrit_scan command. +""" + +import asyncio +import sys +from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter + +from pyrit.cli import frontend_core + + +def parse_args(args=None) -> Namespace: + """Parse command-line arguments for the PyRIT scanner.""" + parser = ArgumentParser( + prog="pyrit_scan", + description="""PyRIT Scanner - Run security scenarios against AI systems + +Examples: + # List available scenarios and initializers + pyrit_scan --list-scenarios + pyrit_scan --list-initializers + + # Run a scenario with built-in initializers + pyrit_scan foundry_scenario --initializers openai_objective_target + + # Run with custom initialization scripts + pyrit_scan encoding_scenario --initialization-scripts ./my_config.py + + # Run specific strategies + pyrit_scan encoding_scenario --initializers openai_objective_target --strategies base64 rot13 + pyrit_scan foundry_scenario --initializers openai_objective_target --max-concurrency 10 --max-retries 3 + pyrit_scan encoding_scenario --initializers openai_objective_target --memory-labels '{"run_id":"test123"}' +""", + formatter_class=RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--log-level", + type=frontend_core.validate_log_level_argparse, + default="WARNING", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING)", + ) + + parser.add_argument( + "--list-scenarios", + action="store_true", + help="List all available scenarios and exit", + ) + + parser.add_argument( + "--list-initializers", + action="store_true", + help="List all available scenario initializers and exit", + ) + + parser.add_argument( + "scenario_name", + type=str, + nargs="?", + help="Name of the scenario to run", + ) + + parser.add_argument( + "--database", + type=frontend_core.validate_database_argparse, + default=frontend_core.SQLITE, + help=( + f"Database type to use for memory storage ({frontend_core.IN_MEMORY}, " + f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE})" + ), + ) + + parser.add_argument( + "--initializers", + type=str, + nargs="+", + help=frontend_core.ARG_HELP["initializers"], + ) + + parser.add_argument( + "--initialization-scripts", + type=str, + nargs="+", + help=frontend_core.ARG_HELP["initialization_scripts"], + ) + + parser.add_argument( + "--strategies", + "-s", + type=str, + nargs="+", + dest="scenario_strategies", + help=frontend_core.ARG_HELP["scenario_strategies"], + ) + + parser.add_argument( + 
"--max-concurrency", + type=frontend_core.positive_int, + help=frontend_core.ARG_HELP["max_concurrency"], + ) + + parser.add_argument( + "--max-retries", + type=frontend_core.non_negative_int, + help=frontend_core.ARG_HELP["max_retries"], + ) + + parser.add_argument( + "--memory-labels", + type=str, + help=frontend_core.ARG_HELP["memory_labels"], + ) + + return parser.parse_args(args) + + +def main(args=None) -> int: + """Main entry point for the PyRIT scanner CLI.""" + print("Starting PyRIT...") + sys.stdout.flush() + + try: + parsed_args = parse_args(args) + except SystemExit as e: + return e.code if isinstance(e.code, int) else 1 + + # Handle list commands (don't need full context) + if parsed_args.list_scenarios: + # Simple context just for listing + initialization_scripts = None + if parsed_args.initialization_scripts: + try: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + except FileNotFoundError as e: + print(f"Error: {e}") + return 1 + + context = frontend_core.FrontendCore( + database=parsed_args.database, + initialization_scripts=initialization_scripts, + log_level=parsed_args.log_level, + ) + + return frontend_core.print_scenarios_list(context=context) + + if parsed_args.list_initializers: + # Discover from scenarios directory + scenarios_path = frontend_core.get_default_initializer_discovery_path() + + context = frontend_core.FrontendCore(log_level=parsed_args.log_level) + return frontend_core.print_initializers_list(context=context, discovery_path=scenarios_path) + + # Verify scenario was provided + if not parsed_args.scenario_name: + print("Error: No scenario specified. Use --help for usage information.") + return 1 + + # Run scenario + try: + # Collect initialization scripts + initialization_scripts = None + if parsed_args.initialization_scripts: + initialization_scripts = frontend_core.resolve_initialization_scripts( + script_paths=parsed_args.initialization_scripts + ) + + # Create context with initializers + context = frontend_core.FrontendCore( + database=parsed_args.database, + initialization_scripts=initialization_scripts, + initializer_names=parsed_args.initializers, + log_level=parsed_args.log_level, + ) + + # Parse memory labels if provided + memory_labels = None + if parsed_args.memory_labels: + memory_labels = frontend_core.parse_memory_labels(json_string=parsed_args.memory_labels) + + # Run scenario + asyncio.run( + frontend_core.run_scenario_async( + scenario_name=parsed_args.scenario_name, + context=context, + scenario_strategies=parsed_args.scenario_strategies, + max_concurrency=parsed_args.max_concurrency, + max_retries=parsed_args.max_retries, + memory_labels=memory_labels, + ) + ) + return 0 + + except Exception as e: + print(f"\nError: {e}") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py new file mode 100644 index 000000000..6eaab8f2b --- /dev/null +++ b/pyrit/cli/pyrit_shell.py @@ -0,0 +1,464 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +PyRIT Shell - Interactive REPL for PyRIT. + +This module provides an interactive shell where PyRIT modules are loaded once +at startup, making subsequent commands instant. 
+""" + +from __future__ import annotations + +import asyncio +import cmd +import sys +import threading +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyrit.models.scenario_result import ScenarioResult + +from pyrit.cli import frontend_core + + +class PyRITShell(cmd.Cmd): + """ + Interactive shell for PyRIT. + + Commands: + list-scenarios - List all available scenarios + list-initializers - List all available initializers + run [opts] - Run a scenario with optional parameters + scenario-history - List all previous scenario runs + print-scenario [N] - Print detailed results for scenario run(s) + help [command] - Show help for a command + clear - Clear the screen + exit (quit, q) - Exit the shell + + Shell Startup Options: + --database Database type (InMemory, SQLite, AzureSQL) - default for all runs + --log-level Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) - default for all runs + + Run Command Options: + --initializers ... Built-in initializers to run before the scenario + --initialization-scripts <...> Custom Python scripts to run before the scenario + --strategies, -s ... Strategy names to use + --max-concurrency Maximum concurrent operations + --max-retries Maximum retry attempts + --memory-labels JSON string of labels + --database Override default database for this run + --log-level Override default log level for this run + """ + + intro = """ +╔════════════════════════════════════════════════════════════════════════╗ +║ ║ +║ ██████╗ ██╗ ██╗██████╗ ██╗████████╗ ║ +║ ██╔══██╗╚██╗ ██╔╝██╔══██╗██║╚══██╔══╝ ║ +║ ██████╔╝ ╚████╔╝ ██████╔╝██║ ██║ ║ +║ ██╔═══╝ ╚██╔╝ ██╔══██╗██║ ██║ ║ +║ ██║ ██║ ██║ ██║██║ ██║ ║ +║ ╚═╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝ ║ +║ ║ +║ Python Risk Identification Tool ║ +║ Interactive Shell ║ +║ ║ +╠════════════════════════════════════════════════════════════════════════╣ +║ ║ +║ Commands: ║ +║ • list-scenarios - See all available scenarios ║ +║ • list-initializers - See all available initializers ║ +║ • run [opts] - Execute a security scenario ║ +║ • scenario-history - View your session history ║ +║ • print-scenario [N] - Display detailed results ║ +║ • help [command] - Get help on any command ║ +║ • exit - Quit the shell ║ +║ ║ +║ Quick Start: ║ +║ pyrit> list-scenarios ║ +║ pyrit> run foundry_scenario --initializers openai_objective_target ║ +║ ║ +╚════════════════════════════════════════════════════════════════════════╝ +""" + prompt = "pyrit> " + + def __init__( + self, + context: frontend_core.FrontendCore, + ): + """ + Initialize the PyRIT shell. + + Args: + context: PyRIT context with loaded registries. + """ + super().__init__() + self.context = context + self.default_database = context._database + self.default_log_level = context._log_level + + # Track scenario execution history: list of (command_string, ScenarioResult) tuples + self._scenario_history: list[tuple[str, ScenarioResult]] = [] + + # Initialize PyRIT in background thread for faster startup + self._init_thread = threading.Thread(target=self._background_init, daemon=True) + self._init_complete = threading.Event() + self._init_thread.start() + + def _background_init(self): + """Initialize PyRIT modules in the background. 
This dramatically speeds up shell startup.""" + print("Loading PyRIT modules...") + sys.stdout.flush() + self.context.initialize() + self._init_complete.set() + + def _ensure_initialized(self): + """Wait for initialization to complete if not already done.""" + if not self._init_complete.is_set(): + print("Waiting for PyRIT initialization to complete...") + sys.stdout.flush() + self._init_complete.wait() + + def do_list_scenarios(self, arg): + """List all available scenarios.""" + self._ensure_initialized() + try: + frontend_core.print_scenarios_list(context=self.context) + except Exception as e: + print(f"Error listing scenarios: {e}") + + def do_list_initializers(self, arg): + """List all available initializers.""" + self._ensure_initialized() + try: + # Parse optional path argument + discovery_path = Path(arg.strip()) if arg.strip() else None + frontend_core.print_initializers_list(context=self.context, discovery_path=discovery_path) + except Exception as e: + print(f"Error listing initializers: {e}") + + def do_run(self, line): + """ + Run a scenario. + + Usage: + run [options] + + Options: + --initializers ... Built-in initializers to run before the scenario + --initialization-scripts <...> Custom Python scripts to run before the scenario + --strategies, -s ... Strategy names to use + --max-concurrency Maximum concurrent operations + --max-retries Maximum retry attempts + --memory-labels JSON string of labels (e.g., '{"key":"value"}') + --database Override default database (InMemory, SQLite, AzureSQL) + --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + + Examples: + run encoding_scenario --initializers openai_objective_target + run encoding_scenario --initializers custom_target --strategies base64 rot13 + run foundry_scenario --initializers openai_objective_target --max-concurrency 10 --max-retries 3 + run encoding_scenario --initializers custom_target --memory-labels '{"run_id":"test123","env":"dev"}' + run foundry_scenario --initializers openai_objective_target -s jailbreak crescendo + run encoding_scenario --initializers openai_objective_target --database InMemory --log-level DEBUG + run foundry_scenario --initialization-scripts ./my_custom_init.py -s all + + Note: + Every scenario requires an initializer (--initializers or --initialization-scripts). + Database and log-level defaults are set at shell startup but can be overridden per-run. + Initializers are specified per-run to allow different setups for different scenarios. + """ + self._ensure_initialized() + if not line.strip(): + print("Error: Specify a scenario name") + print("\nUsage: run [options]") + print("\nNote: Every scenario requires an initializer.") + print("\nOptions:") + print(f" --initializers ... {frontend_core.ARG_HELP['initializers']} (REQUIRED)") + print( + f" --initialization-scripts <...> {frontend_core.ARG_HELP['initialization_scripts']} (alternative to --initializers)" + ) + print(f" --strategies, -s ... 
{frontend_core.ARG_HELP['scenario_strategies']}") + print(f" --max-concurrency {frontend_core.ARG_HELP['max_concurrency']}") + print(f" --max-retries {frontend_core.ARG_HELP['max_retries']}") + print(f" --memory-labels {frontend_core.ARG_HELP['memory_labels']}") + print( + f" --database Override default database ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL})" + ) + print( + f" --log-level Override default log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)" + ) + print("\nExample:") + print(" run foundry_scenario --initializers openai_objective_target") + print("\nType 'help run' for more details and examples") + return + + # Parse arguments using shared parser + try: + args = frontend_core.parse_run_arguments(args_string=line) + except ValueError as e: + print(f"Error: {e}") + return + + # Resolve initialization scripts if provided + resolved_scripts = None + if args["initialization_scripts"]: + try: + resolved_scripts = frontend_core.resolve_initialization_scripts( + script_paths=args["initialization_scripts"] + ) + except FileNotFoundError as e: + print(f"Error: {e}") + return + + # Create a context for this run with overrides + run_context = frontend_core.FrontendCore( + database=args["database"] or self.default_database, + initialization_scripts=resolved_scripts, + initializer_names=args["initializers"], + log_level=args["log_level"] or self.default_log_level, + ) + # Use the existing registries (don't reinitialize) + run_context._scenario_registry = self.context._scenario_registry + run_context._initializer_registry = self.context._initializer_registry + run_context._initialized = True + + try: + result = asyncio.run( + frontend_core.run_scenario_async( + scenario_name=args["scenario_name"], + context=run_context, + scenario_strategies=args["scenario_strategies"], + max_concurrency=args["max_concurrency"], + max_retries=args["max_retries"], + memory_labels=args["memory_labels"], + ) + ) + # Store the command and result in history + self._scenario_history.append((line, result)) + except ValueError as e: + print(f"Error: {e}") + except Exception as e: + print(f"Error running scenario: {e}") + import traceback + + traceback.print_exc() + + def do_scenario_history(self, arg): + """ + Display history of scenario runs. + + Usage: + scenario-history + + Shows a numbered list of all scenario runs with the commands used. + """ + if not self._scenario_history: + print("No scenario runs in history.") + return + + print("\nScenario Run History:") + print("=" * 80) + for idx, (command, _) in enumerate(self._scenario_history, start=1): + print(f"{idx}) {command}") + print("=" * 80) + print(f"\nTotal runs: {len(self._scenario_history)}") + print("\nUse 'print-scenario ' to view detailed results for a specific run.") + print("Use 'print-scenario' to view detailed results for all runs.") + + def do_print_scenario(self, arg): + """ + Print detailed results for scenario runs. 
+ + Usage: + print-scenario Print all scenario results + print-scenario Print results for scenario run number N + + Examples: + print-scenario Show all previous scenario results + print-scenario 1 Show results from first scenario run + print-scenario 3 Show results from third scenario run + """ + if not self._scenario_history: + print("No scenario runs in history.") + return + + # Parse argument + arg = arg.strip() + + if not arg: + # Print all scenarios + print("\nPrinting all scenario results:") + print("=" * 80) + for idx, (command, result) in enumerate(self._scenario_history, start=1): + print(f"\n{'#' * 80}") + print(f"Scenario Run #{idx}: {command}") + print(f"{'#' * 80}") + from pyrit.scenarios.printer.console_printer import ( + ConsoleScenarioResultPrinter, + ) + + printer = ConsoleScenarioResultPrinter() + asyncio.run(printer.print_summary_async(result)) + else: + # Print specific scenario + try: + scenario_num = int(arg) + if scenario_num < 1 or scenario_num > len(self._scenario_history): + print(f"Error: Scenario number must be between 1 and {len(self._scenario_history)}") + return + + command, result = self._scenario_history[scenario_num - 1] + print(f"\nScenario Run #{scenario_num}: {command}") + print("=" * 80) + from pyrit.scenarios.printer.console_printer import ( + ConsoleScenarioResultPrinter, + ) + + printer = ConsoleScenarioResultPrinter() + asyncio.run(printer.print_summary_async(result)) + except ValueError: + print(f"Error: Invalid scenario number '{arg}'. Must be an integer.") + + def do_help(self, arg): + """Show help. Usage: help [command].""" + if not arg: + # Show general help + super().do_help(arg) + print("\n" + "=" * 70) + print("Shell Startup Options:") + print("=" * 70) + print(" --database ") + print(" Default database type: InMemory, SQLite, or AzureSQL") + print(" Default: SQLite") + print(" Can be overridden per-run with 'run --database '") + print() + print(" --log-level ") + print(" Default logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL") + print(" Default: WARNING") + print(" Can be overridden per-run with 'run --log-level '") + print() + print("=" * 70) + print("Run Command Options (specified when running scenarios):") + print("=" * 70) + print(" --initializers [ ...] (REQUIRED)") + print(f" {frontend_core.ARG_HELP['initializers']}") + print(" Every scenario requires at least one initializer") + print(" Example: run foundry_scenario --initializers openai_objective_target") + print() + print(" --initialization-scripts [ ...] (Alternative to --initializers)") + print(f" {frontend_core.ARG_HELP['initialization_scripts']}") + print(" Example: run foundry_scenario --initialization-scripts ./my_init.py") + print() + print(" --strategies, -s [ ...]") + print(f" {frontend_core.ARG_HELP['scenario_strategies']}") + print(" Example: run encoding_scenario --strategies base64 rot13") + print() + print(" --max-concurrency ") + print(f" {frontend_core.ARG_HELP['max_concurrency']}") + print() + print(" --max-retries ") + print(f" {frontend_core.ARG_HELP['max_retries']}") + print() + print(" --memory-labels ") + print(f" {frontend_core.ARG_HELP['memory_labels']}") + print(' Example: run foundry_scenario --memory-labels \'{"env":"test"}\'') + print() + print("Start the shell like:") + print(" pyrit_shell") + print(" pyrit_shell --database InMemory --log-level DEBUG") + else: + # Show help for specific command + super().do_help(arg) + + def do_exit(self, arg): + """Exit the shell. 
Aliases: quit, q.""" + print("\nGoodbye!") + return True + + def do_clear(self, arg): + """Clear the screen.""" + import os + + os.system("cls" if os.name == "nt" else "clear") + + # Shortcuts and aliases + do_quit = do_exit + do_q = do_exit + do_EOF = do_exit # Ctrl+D on Unix, Ctrl+Z on Windows + + def emptyline(self) -> bool: + """Don't repeat last command on empty line.""" + return False + + def default(self, line): + """Handle unknown commands and convert hyphens to underscores.""" + # Try converting hyphens to underscores for command lookup + parts = line.split(None, 1) + if parts: + cmd_with_underscores = parts[0].replace("-", "_") + method_name = f"do_{cmd_with_underscores}" + + if hasattr(self, method_name): + # Call the method with the rest of the line as argument + arg = parts[1] if len(parts) > 1 else "" + return getattr(self, method_name)(arg) + + print(f"Unknown command: {line}") + print("Type 'help' or '?' for available commands") + + +def main(): + """Entry point for pyrit_shell.""" + import argparse + + parser = argparse.ArgumentParser( + prog="pyrit_shell", + description="PyRIT Interactive Shell - Load modules once, run commands instantly", + ) + + parser.add_argument( + "--database", + choices=[frontend_core.IN_MEMORY, frontend_core.SQLITE, frontend_core.AZURE_SQL], + default=frontend_core.SQLITE, + help=f"Default database type to use ({frontend_core.IN_MEMORY}, {frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE}, can be overridden per-run)", + ) + + parser.add_argument( + "--log-level", + type=str, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + default="WARNING", + help="Default logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) (default: WARNING, can be overridden per-run)", + ) + + args = parser.parse_args() + + # Create context (initializers are specified per-run, not at startup) + context = frontend_core.FrontendCore( + database=args.database, + initialization_scripts=None, + initializer_names=None, + log_level=args.log_level, + ) + + # Start shell + try: + shell = PyRITShell(context) + shell.cmdloop() + return 0 + except KeyboardInterrupt: + print("\n\nInterrupted. Goodbye!") + return 0 + except Exception as e: + print(f"\nError: {e}") + import traceback + + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyrit/cli/scenario_registry.py b/pyrit/cli/scenario_registry.py index f9ca8c926..4fbf849c8 100644 --- a/pyrit/cli/scenario_registry.py +++ b/pyrit/cli/scenario_registry.py @@ -1,11 +1,24 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + """ Scenario registry for discovering and instantiating PyRIT scenarios. This module provides functionality to discover all available Scenario subclasses from the pyrit.scenarios.scenarios module and from user-defined initialization scripts. + +PERFORMANCE OPTIMIZATION: +This module uses lazy imports to minimize overhead during CLI operations: + +1. Lazy Imports via TYPE_CHECKING: Heavy dependencies (like Scenario base class) are only + imported for type checking, not at runtime. Runtime imports happen inside methods only + when actually needed. + +2. Direct Path Computation: Computes PYRIT_PATH directly using __file__ instead of importing + from pyrit.common.path, which would trigger loading the entire pyrit package (including + heavy dependencies like transformers). 
""" import importlib @@ -13,9 +26,16 @@ import logging import pkgutil from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Type + +# Compute PYRIT_PATH directly to avoid importing pyrit package +# (which triggers heavy imports from __init__.py) +PYRIT_PATH = Path(__file__).parent.parent.resolve() -from pyrit.scenarios.scenario import Scenario +# Lazy import to avoid loading heavy scenario modules when just listing scenarios +if TYPE_CHECKING: + from pyrit.cli.frontend_core import ScenarioInfo + from pyrit.scenarios.scenario import Scenario logger = logging.getLogger(__name__) @@ -32,9 +52,16 @@ class ScenarioRegistry: """ def __init__(self) -> None: - """Initialize the scenario registry.""" + """Initialize the scenario registry with lazy discovery.""" self._scenarios: Dict[str, Type[Scenario]] = {} - self._discover_builtin_scenarios() + self._scenario_metadata: Optional[List[ScenarioInfo]] = None + self._discovered = False + + def _ensure_discovered(self) -> None: + """Ensure scenarios have been discovered. Discovers on first call only.""" + if not self._discovered: + self._discover_builtin_scenarios() + self._discovered = True def _discover_builtin_scenarios(self) -> None: """ @@ -43,6 +70,8 @@ def _discover_builtin_scenarios(self) -> None: This method dynamically imports all modules in the scenarios package and registers any Scenario subclasses found. """ + from pyrit.scenarios.scenario import Scenario + try: import pyrit.scenarios.scenarios as scenarios_package @@ -92,6 +121,8 @@ def discover_user_scenarios(self) -> None: User scenarios will override built-in scenarios with the same name. """ + from pyrit.scenarios.scenario import Scenario + try: # Check the global namespace for Scenario subclasses import sys @@ -155,20 +186,29 @@ def get_scenario(self, name: str) -> Optional[Type[Scenario]]: Returns: Optional[Type[Scenario]]: The scenario class, or None if not found. """ + self._ensure_discovered() return self._scenarios.get(name) - def list_scenarios(self) -> list[dict[str, Sequence[Any]]]: + def list_scenarios(self) -> List[ScenarioInfo]: """ List all available scenarios with their metadata. Returns: - List[Dict[str, str]]: List of scenario information dictionaries containing: + List[ScenarioInfo]: List of scenario information dictionaries containing: - name: Scenario identifier - class_name: Class name - description: Full class docstring - default_strategy: The default strategy used when none specified + - all_strategies: All available strategy values + - aggregate_strategies: Aggregate strategy values """ - scenarios_info = [] + # If we already have metadata, return it + if self._scenario_metadata is not None: + return self._scenario_metadata + + # Discover scenarios and build metadata + self._ensure_discovered() + scenarios_info: List[ScenarioInfo] = [] for name, scenario_class in sorted(self._scenarios.items()): # Extract full docstring as description, clean up whitespace @@ -193,6 +233,9 @@ def list_scenarios(self) -> list[dict[str, Sequence[Any]]]: } ) + # Cache metadata for subsequent calls + self._scenario_metadata = scenarios_info + return scenarios_info def get_scenario_names(self) -> List[str]: @@ -202,4 +245,5 @@ def get_scenario_names(self) -> List[str]: Returns: List[str]: Sorted list of scenario identifiers. 
""" + self._ensure_discovered() return sorted(self._scenarios.keys()) diff --git a/pyrit/show_versions.py b/pyrit/show_versions.py index 09f4ce061..c3101f5ef 100644 --- a/pyrit/show_versions.py +++ b/pyrit/show_versions.py @@ -5,8 +5,7 @@ Utility methods to print system info for debugging. Adapted from :py:func:`pandas.show_versions` and :py:func:`sklearn.show_versions`. -""" # noqa: RST304 - +""" import platform import sys diff --git a/tests/unit/cli/test_cli_main.py b/tests/unit/cli/test_cli_main.py deleted file mode 100644 index e5620cac7..000000000 --- a/tests/unit/cli/test_cli_main.py +++ /dev/null @@ -1,628 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Unit tests for the PyRIT CLI main module. - -Tests command-line argument parsing, scenario listing, and main entry point. -For InitializerRegistry tests, see test_initializer_registry.py. -For ScenarioRegistry tests, see test_scenario_registry.py. -""" - -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from pyrit.cli.__main__ import list_scenarios, main, parse_args, run_scenario_async -from pyrit.scenarios import ScenarioStrategy - - -class TestParseArgs: - """Tests for command-line argument parsing.""" - - def test_parse_args_list_scenarios(self): - """Test parsing --list-scenarios flag.""" - args = parse_args(["--list-scenarios"]) - assert args.list_scenarios is True - assert args.verbose is False - - def test_parse_args_list_initializers(self): - """Test parsing --list-initializers flag.""" - args = parse_args(["--list-initializers"]) - assert args.list_initializers is True - assert args.verbose is False - - def test_parse_args_scenario_basic(self): - """Test parsing with just scenario name.""" - args = parse_args(["encoding_scenario"]) - assert args.scenario_name == "encoding_scenario" - assert args.database == "SQLite" - assert args.verbose is False - - def test_parse_args_scenario_with_database(self): - """Test parsing scenario with database option.""" - args = parse_args(["foundry_scenario", "--database", "InMemory"]) - assert args.scenario_name == "foundry_scenario" - assert args.database == "InMemory" - - def test_parse_args_scenario_with_initializers(self): - """Test parsing scenario with initializers option.""" - args = parse_args(["encoding_scenario", "--initializers", "simple", "airt"]) - assert args.scenario_name == "encoding_scenario" - assert args.initializers == ["simple", "airt"] - - def test_parse_args_scenario_with_dot_notation_initializers(self): - """Test parsing scenario with dot notation initializers.""" - args = parse_args(["foundry_scenario", "--initializers", "scenarios.objective_target"]) - assert args.scenario_name == "foundry_scenario" - assert args.initializers == ["scenarios.objective_target"] - - def test_parse_args_scenario_with_initialization_scripts(self): - """Test parsing scenario with initialization scripts.""" - args = parse_args(["foundry_scenario", "--initialization-scripts", "script1.py", "script2.py"]) - assert args.scenario_name == "foundry_scenario" - assert args.initialization_scripts == ["script1.py", "script2.py"] - - def test_parse_args_scenario_all_options(self): - """Test parsing scenario with all options.""" - args = parse_args( - [ - "--verbose", - "encoding_scenario", - "--database", - "AzureSQL", - "--initializers", - "simple", - "--initialization-scripts", - "custom.py", - ] - ) - assert args.scenario_name == "encoding_scenario" - assert args.database == "AzureSQL" - assert 
args.initializers == ["simple"] - assert args.initialization_scripts == ["custom.py"] - assert args.verbose is True - - def test_parse_args_no_scenario(self): - """Test parsing with no scenario name.""" - args = parse_args([]) - assert args.scenario_name is None - - def test_parse_args_invalid_database(self): - """Test parsing with invalid database choice.""" - with pytest.raises(SystemExit): - parse_args(["encoding_scenario", "--database", "InvalidDB"]) - - def test_parse_args_list_scenarios_with_initialization_scripts(self): - """Test parsing --list-scenarios with initialization scripts.""" - args = parse_args(["--list-scenarios", "--initialization-scripts", "custom.py"]) - assert args.list_scenarios is True - assert args.initialization_scripts == ["custom.py"] - - def test_parse_args_scenario_with_single_strategy(self): - """Test parsing scenario with single strategy.""" - args = parse_args(["foundry_scenario", "--scenario-strategies", "base64"]) - assert args.scenario_name == "foundry_scenario" - assert args.scenario_strategies == ["base64"] - - def test_parse_args_scenario_with_multiple_strategies(self): - """Test parsing scenario with multiple strategies.""" - args = parse_args(["foundry_scenario", "--scenario-strategies", "base64", "rot13", "morse"]) - assert args.scenario_name == "foundry_scenario" - assert args.scenario_strategies == ["base64", "rot13", "morse"] - - def test_parse_args_scenario_with_strategies_and_initializers(self): - """Test parsing scenario with both strategies and initializers.""" - args = parse_args( - [ - "foundry_scenario", - "--initializers", - "simple", - "objective_target", - "--scenario-strategies", - "base64", - "atbash", - ] - ) - assert args.scenario_name == "foundry_scenario" - assert args.initializers == ["simple", "objective_target"] - assert args.scenario_strategies == ["base64", "atbash"] - - -class TestListScenarios: - """Tests for list_scenarios function.""" - - def test_list_scenarios_no_scenarios(self, capsys): - """Test list_scenarios with no scenarios available.""" - registry = MagicMock() - registry.list_scenarios.return_value = [] - - list_scenarios(registry=registry) - - captured = capsys.readouterr() - assert "No scenarios found" in captured.out - - def test_list_scenarios_with_scenarios(self, capsys): - """Test list_scenarios with available scenarios.""" - registry = MagicMock() - registry.list_scenarios.return_value = [ - { - "name": "encoding_scenario", - "class_name": "EncodingScenario", - "description": "Test encoding attacks on AI models", - "default_strategy": "all", - "all_strategies": ["base64", "rot13"], - "aggregate_strategies": ["all"], - }, - { - "name": "foundry_scenario", - "class_name": "FoundryScenario", - "description": "Comprehensive security testing scenario", - "default_strategy": "easy", - "all_strategies": ["base64", "crescendo"], - "aggregate_strategies": ["easy", "moderate", "difficult"], - }, - ] - - list_scenarios(registry=registry) - - captured = capsys.readouterr() - assert "Available Scenarios:" in captured.out - assert "encoding_scenario" in captured.out - assert "EncodingScenario" in captured.out - assert "foundry_scenario" in captured.out - assert "FoundryScenario" in captured.out - assert "Total scenarios: 2" in captured.out - - -class TestMain: - """Tests for main function.""" - - def test_main_no_scenario(self, capsys): - """Test main with no scenario returns error.""" - result = main([]) - assert result == 1 - - captured = capsys.readouterr() - assert "Error: No scenario specified" in captured.out - 
- @patch("pyrit.cli.__main__.ScenarioRegistry") - def test_main_list_scenarios(self, mock_registry_class, capsys): - """Test main with --list-scenarios flag.""" - mock_registry = MagicMock() - mock_registry.list_scenarios.return_value = [ - { - "name": "encoding_scenario", - "class_name": "EncodingScenario", - "description": "Test scenario", - "default_strategy": "all", - "all_strategies": ["base64"], - "aggregate_strategies": ["all"], - } - ] - mock_registry_class.return_value = mock_registry - - result = main(["--list-scenarios"]) - - assert result == 0 - mock_registry_class.assert_called_once() - mock_registry.list_scenarios.assert_called_once() - - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.InitializerRegistry") - @patch("pyrit.cli.__main__.ScenarioRegistry") - def test_main_list_scenarios_with_scripts( - self, mock_registry_class, mock_init_registry_class, mock_init_pyrit, capsys - ): - """Test main with --list-scenarios and initialization scripts.""" - mock_registry = MagicMock() - mock_registry.list_scenarios.return_value = [] - mock_registry_class.return_value = mock_registry - - # Mock the resolve_script_paths static method - mock_init_registry_class.resolve_script_paths.return_value = [Path("/fake/custom.py")] - - result = main(["--list-scenarios", "--initialization-scripts", "custom.py"]) - - assert result == 0 - mock_init_pyrit.assert_called_once() - mock_registry.discover_user_scenarios.assert_called_once() - - @patch("pyrit.cli.__main__.list_initializers") - def test_main_list_initializers(self, mock_list_init): - """Test main with --list-initializers flag.""" - result = main(["--list-initializers"]) - - assert result == 0 - mock_list_init.assert_called_once() - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - def test_main_scenario_command(self, mock_registry_class, mock_init_pyrit, mock_asyncio_run): - """Test main with scenario name.""" - mock_registry = MagicMock() - mock_registry_class.return_value = mock_registry - - result = main(["encoding_scenario"]) - - assert result == 0 - mock_init_pyrit.assert_called_once() - mock_registry_class.assert_called_once() - mock_asyncio_run.assert_called_once() - - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - @patch("pyrit.cli.__main__.InitializerRegistry") - def test_main_scenario_with_missing_script( - self, mock_init_registry_class, mock_registry_class, mock_init_pyrit, capsys - ): - """Test main with scenario and missing initialization script.""" - # Make resolve_script_paths raise FileNotFoundError - mock_init_registry_class.resolve_script_paths.side_effect = FileNotFoundError( - "Initialization script not found: /fake/path/missing.py" - ) - - result = main(["encoding_scenario", "--initialization-scripts", "missing.py"]) - - assert result == 1 - captured = capsys.readouterr() - assert "Initialization script not found" in captured.out - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - @patch("pyrit.cli.__main__.InitializerRegistry") - def test_main_scenario_with_initializers( - self, mock_init_registry_class, mock_registry_class, mock_init_pyrit, mock_asyncio_run - ): - """Test main with scenario and built-in initializers.""" - mock_registry = MagicMock() - mock_registry_class.return_value = mock_registry - - # Mock the resolve_initializer_paths method - 
mock_init_registry_instance = MagicMock() - mock_init_registry_instance.resolve_initializer_paths.return_value = [Path("/fake/simple.py")] - mock_init_registry_class.return_value = mock_init_registry_instance - - result = main(["encoding_scenario", "--initializers", "simple"]) - - assert result == 0 - mock_init_pyrit.assert_called_once() - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - @patch("pyrit.cli.__main__.InitializerRegistry") - def test_main_scenario_with_invalid_initializer( - self, mock_init_registry_class, mock_registry_class, mock_init_pyrit, mock_asyncio_run, capsys - ): - """Test main with scenario and invalid initializer name.""" - # Mock the registry instance to raise ValueError - mock_init_registry_instance = MagicMock() - mock_init_registry_instance.resolve_initializer_paths.side_effect = ValueError( - "Built-in initializer 'invalid' not found" - ) - mock_init_registry_class.return_value = mock_init_registry_instance - - result = main(["encoding_scenario", "--initializers", "invalid"]) - - assert result == 1 - captured = capsys.readouterr() - assert "Built-in initializer 'invalid' not found" in captured.out - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - def test_main_scenario_with_database_option(self, mock_registry_class, mock_init_pyrit, mock_asyncio_run): - """Test main with scenario and database option.""" - mock_registry = MagicMock() - mock_registry_class.return_value = mock_registry - - result = main(["encoding_scenario", "--database", "InMemory"]) - - assert result == 0 - # Verify initialize_pyrit was called with InMemory - call_args = mock_init_pyrit.call_args - assert call_args[1]["memory_db_type"] == "InMemory" - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - def test_main_scenario_exception_handling(self, mock_registry_class, mock_init_pyrit, mock_asyncio_run, capsys): - """Test main handles exceptions properly.""" - mock_init_pyrit.side_effect = Exception("Test error") - - result = main(["encoding_scenario"]) - - assert result == 1 - captured = capsys.readouterr() - assert "Error: Test error" in captured.out - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - def test_main_scenario_with_single_strategy(self, mock_registry_class, mock_init_pyrit, mock_asyncio_run): - """Test main with scenario and single strategy.""" - mock_registry = MagicMock() - mock_registry_class.return_value = mock_registry - - result = main(["foundry_scenario", "--scenario-strategies", "base64"]) - - assert result == 0 - mock_init_pyrit.assert_called_once() - mock_registry_class.assert_called_once() - - # Verify asyncio.run was called with the correct parameters - mock_asyncio_run.assert_called_once() - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - def test_main_scenario_with_multiple_strategies(self, mock_registry_class, mock_init_pyrit, mock_asyncio_run): - """Test main with scenario and multiple strategies.""" - mock_registry = MagicMock() - mock_registry_class.return_value = mock_registry - - result = main(["foundry_scenario", "--scenario-strategies", "base64", "rot13", "morse"]) - - assert result == 0 - 
mock_init_pyrit.assert_called_once() - mock_registry_class.assert_called_once() - mock_asyncio_run.assert_called_once() - - @patch("pyrit.cli.__main__.asyncio.run") - @patch("pyrit.cli.__main__.initialize_pyrit") - @patch("pyrit.cli.__main__.ScenarioRegistry") - @patch("pyrit.cli.__main__.InitializerRegistry") - def test_main_scenario_with_strategies_and_initializers( - self, mock_init_registry_class, mock_registry_class, mock_init_pyrit, mock_asyncio_run - ): - """Test main with scenario, strategies, and initializers.""" - mock_registry = MagicMock() - mock_registry_class.return_value = mock_registry - - # Mock the resolve_initializer_paths method - mock_init_registry_instance = MagicMock() - mock_init_registry_instance.resolve_initializer_paths.return_value = [Path("/fake/simple.py")] - mock_init_registry_class.return_value = mock_init_registry_instance - - result = main( - [ - "foundry_scenario", - "--initializers", - "simple", - "objective_target", - "--scenario-strategies", - "base64", - "atbash", - ] - ) - - assert result == 0 - mock_init_pyrit.assert_called_once() - mock_asyncio_run.assert_called_once() - - -class TestRunScenarioAsync: - """Tests for run_scenario_async function.""" - - @pytest.mark.asyncio - async def test_run_scenario_async_scenario_not_found(self): - """Test run_scenario_async with non-existent scenario.""" - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = None - mock_registry.get_scenario_names.return_value = ["encoding_scenario", "foundry_scenario"] - - with pytest.raises(ValueError, match="Scenario 'invalid_scenario' not found"): - await run_scenario_async(scenario_name="invalid_scenario", registry=mock_registry, scenario_strategies=None) - - @pytest.mark.asyncio - async def test_run_scenario_async_without_strategies(self): - """Test run_scenario_async without providing strategies.""" - # Create a mock scenario class - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_scenario_instance.name = "Test Scenario" - mock_scenario_instance.atomic_attack_count = 2 - - # Use AsyncMock for async methods - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=MagicMock()) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_class.__name__ = "TestScenario" - - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = mock_scenario_class - - # Mock the printer with async method - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - with patch("pyrit.cli.__main__.ConsoleScenarioResultPrinter", return_value=mock_printer): - await run_scenario_async(scenario_name="test_scenario", registry=mock_registry, scenario_strategies=None) - - # Verify scenario was instantiated without strategies - mock_scenario_class.assert_called_once_with() - - @pytest.mark.asyncio - async def test_run_scenario_async_with_single_strategy(self): - """Test run_scenario_async with a single strategy.""" - - # Create a mock strategy enum - class MockStrategy(ScenarioStrategy): - ALL = ("all", {"all"}) - Base64 = ("base64", {"test"}) - - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} - - # Create a mock scenario class - mock_scenario_class = MagicMock() - mock_scenario_class.get_strategy_class.return_value = MockStrategy - mock_scenario_instance = MagicMock() - mock_scenario_instance.name = "Test Scenario" - mock_scenario_instance.atomic_attack_count = 1 - - # Use AsyncMock for async methods - 
mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=MagicMock()) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_class.__name__ = "TestScenario" - - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = mock_scenario_class - - # Mock the printer with async method - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - with patch("pyrit.cli.__main__.ConsoleScenarioResultPrinter", return_value=mock_printer): - await run_scenario_async( - scenario_name="test_scenario", registry=mock_registry, scenario_strategies=["base64"] - ) - - # Verify scenario was instantiated without strategies (they go to initialize_async now) - mock_scenario_class.assert_called_once_with() - - # Verify initialize_async was called with strategies - mock_scenario_instance.initialize_async.assert_called_once() - init_call_kwargs = mock_scenario_instance.initialize_async.call_args[1] - assert "scenario_strategies" in init_call_kwargs - strategies = init_call_kwargs["scenario_strategies"] - assert len(strategies) == 1 - - @pytest.mark.asyncio - async def test_run_scenario_async_with_multiple_strategies(self): - """Test run_scenario_async with multiple strategies.""" - - # Create a mock strategy enum - class MockStrategy(ScenarioStrategy): - ALL = ("all", {"all"}) - Base64 = ("base64", {"test"}) - ROT13 = ("rot13", {"test"}) - Morse = ("morse", {"test"}) - - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} - - # Create a mock scenario class - mock_scenario_class = MagicMock() - mock_scenario_class.get_strategy_class.return_value = MockStrategy - mock_scenario_instance = MagicMock() - mock_scenario_instance.name = "Test Scenario" - mock_scenario_instance.atomic_attack_count = 3 - - # Use AsyncMock for async methods - mock_scenario_instance.initialize_async = AsyncMock() - mock_scenario_instance.run_async = AsyncMock(return_value=MagicMock()) - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_class.__name__ = "TestScenario" - - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = mock_scenario_class - - # Mock the printer with async method - mock_printer = MagicMock() - mock_printer.print_summary_async = AsyncMock() - with patch("pyrit.cli.__main__.ConsoleScenarioResultPrinter", return_value=mock_printer): - await run_scenario_async( - scenario_name="test_scenario", registry=mock_registry, scenario_strategies=["base64", "rot13", "morse"] - ) - - # Verify scenario was instantiated without strategies (they go to initialize_async now) - mock_scenario_class.assert_called_once_with() - - # Verify initialize_async was called with strategies - mock_scenario_instance.initialize_async.assert_called_once() - init_call_kwargs = mock_scenario_instance.initialize_async.call_args[1] - assert "scenario_strategies" in init_call_kwargs - strategies = init_call_kwargs["scenario_strategies"] - assert len(strategies) == 3 - - @pytest.mark.asyncio - async def test_run_scenario_async_with_invalid_strategy(self): - """Test run_scenario_async with an invalid strategy name.""" - - # Create a mock strategy enum - class MockStrategy(ScenarioStrategy): - ALL = ("all", {"all"}) - Base64 = ("base64", {"test"}) - - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} - - # Create a mock scenario class - mock_scenario_class = MagicMock() - mock_scenario_class.get_strategy_class.return_value = MockStrategy - mock_scenario_class.__name__ = 
"TestScenario" - - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = mock_scenario_class - - with pytest.raises(ValueError, match="Strategy 'invalid_strategy' not found"): - await run_scenario_async( - scenario_name="test_scenario", registry=mock_registry, scenario_strategies=["invalid_strategy"] - ) - - @pytest.mark.asyncio - async def test_run_scenario_async_initialization_failure(self): - """Test run_scenario_async handles initialization failure.""" - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_scenario_instance.name = "Test Scenario" - - # Make async method raise exception - async def mock_initialize(): - raise Exception("Initialization failed") - - mock_scenario_instance.initialize_async = mock_initialize - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_class.__name__ = "TestScenario" - - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = mock_scenario_class - - with pytest.raises(ValueError, match="Failed to initialize scenario"): - await run_scenario_async(scenario_name="test_scenario", registry=mock_registry, scenario_strategies=None) - - @pytest.mark.asyncio - async def test_run_scenario_async_run_failure(self): - """Test run_scenario_async handles run failure.""" - mock_scenario_class = MagicMock() - mock_scenario_instance = MagicMock() - mock_scenario_instance.name = "Test Scenario" - mock_scenario_instance.atomic_attack_count = 1 - - # Make async methods - async def mock_initialize(): - pass - - async def mock_run(): - raise Exception("Run failed") - - mock_scenario_instance.initialize_async = mock_initialize - mock_scenario_instance.run_async = mock_run - mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_class.__name__ = "TestScenario" - - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = mock_scenario_class - - with pytest.raises(ValueError, match="Failed to run scenario"): - await run_scenario_async(scenario_name="test_scenario", registry=mock_registry, scenario_strategies=None) - - @pytest.mark.asyncio - async def test_run_scenario_async_missing_required_parameters(self): - """Test run_scenario_async with scenario missing required parameters.""" - mock_scenario_class = MagicMock() - mock_scenario_class.side_effect = TypeError("missing 1 required positional argument: 'objective_target'") - mock_scenario_class.__name__ = "TestScenario" - - mock_registry = MagicMock() - mock_registry.get_scenario.return_value = mock_scenario_class - - with pytest.raises(ValueError, match="requires parameters that are not configured"): - await run_scenario_async(scenario_name="test_scenario", registry=mock_registry, scenario_strategies=None) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py new file mode 100644 index 000000000..9b89251bd --- /dev/null +++ b/tests/unit/cli/test_frontend_core.py @@ -0,0 +1,836 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for the frontend_core module. 
+""" + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.cli import frontend_core + + +class TestFrontendCore: + """Tests for FrontendCore class.""" + + def test_init_with_defaults(self): + """Test initialization with default parameters.""" + context = frontend_core.FrontendCore() + + assert context._database == frontend_core.SQLITE + assert context._initialization_scripts is None + assert context._initializer_names is None + assert context._log_level == "WARNING" + assert context._initialized is False + + def test_init_with_all_parameters(self): + """Test initialization with all parameters.""" + scripts = [Path("/test/script.py")] + initializers = ["test_init"] + + context = frontend_core.FrontendCore( + database=frontend_core.IN_MEMORY, + initialization_scripts=scripts, + initializer_names=initializers, + log_level="DEBUG", + ) + + assert context._database == frontend_core.IN_MEMORY + assert context._initialization_scripts == scripts + assert context._initializer_names == initializers + assert context._log_level == "DEBUG" + + def test_init_with_invalid_database(self): + """Test initialization with invalid database raises ValueError.""" + with pytest.raises(ValueError, match="Invalid database type"): + frontend_core.FrontendCore(database="InvalidDB") + + def test_init_with_invalid_log_level(self): + """Test initialization with invalid log level raises ValueError.""" + with pytest.raises(ValueError, match="Invalid log level"): + frontend_core.FrontendCore(log_level="INVALID") + + @patch("pyrit.cli.scenario_registry.ScenarioRegistry") + @patch("pyrit.cli.initializer_registry.InitializerRegistry") + @patch("pyrit.setup.initialize_pyrit") + def test_initialize_loads_registries( + self, + mock_init_pyrit: MagicMock, + mock_init_registry: MagicMock, + mock_scenario_registry: MagicMock, + ): + """Test initialize method loads registries.""" + context = frontend_core.FrontendCore() + context.initialize() + + assert context._initialized is True + mock_init_pyrit.assert_called_once() + mock_scenario_registry.assert_called_once() + mock_init_registry.assert_called_once() + + @patch("pyrit.cli.scenario_registry.ScenarioRegistry") + @patch("pyrit.cli.initializer_registry.InitializerRegistry") + @patch("pyrit.setup.initialize_pyrit") + def test_scenario_registry_property_initializes( + self, + mock_init_pyrit: MagicMock, + mock_init_registry: MagicMock, + mock_scenario_registry: MagicMock, + ): + """Test scenario_registry property triggers initialization.""" + context = frontend_core.FrontendCore() + assert context._initialized is False + + registry = context.scenario_registry + + assert context._initialized is True + assert registry is not None + + @patch("pyrit.cli.scenario_registry.ScenarioRegistry") + @patch("pyrit.cli.initializer_registry.InitializerRegistry") + @patch("pyrit.setup.initialize_pyrit") + def test_initializer_registry_property_initializes( + self, + mock_init_pyrit: MagicMock, + mock_init_registry: MagicMock, + mock_scenario_registry: MagicMock, + ): + """Test initializer_registry property triggers initialization.""" + context = frontend_core.FrontendCore() + assert context._initialized is False + + registry = context.initializer_registry + + assert context._initialized is True + assert registry is not None + + +class TestValidationFunctions: + """Tests for validation functions.""" + + def test_validate_database_valid_values(self): + """Test validate_database with valid values.""" + assert 
frontend_core.validate_database(database=frontend_core.IN_MEMORY) == frontend_core.IN_MEMORY + assert frontend_core.validate_database(database=frontend_core.SQLITE) == frontend_core.SQLITE + assert frontend_core.validate_database(database=frontend_core.AZURE_SQL) == frontend_core.AZURE_SQL + + def test_validate_database_invalid_value(self): + """Test validate_database with invalid value.""" + with pytest.raises(ValueError, match="Invalid database type"): + frontend_core.validate_database(database="InvalidDB") + + def test_validate_log_level_valid_values(self): + """Test validate_log_level with valid values.""" + assert frontend_core.validate_log_level(log_level="DEBUG") == "DEBUG" + assert frontend_core.validate_log_level(log_level="INFO") == "INFO" + assert frontend_core.validate_log_level(log_level="warning") == "WARNING" # Case-insensitive + assert frontend_core.validate_log_level(log_level="error") == "ERROR" + assert frontend_core.validate_log_level(log_level="CRITICAL") == "CRITICAL" + + def test_validate_log_level_invalid_value(self): + """Test validate_log_level with invalid value.""" + with pytest.raises(ValueError, match="Invalid log level"): + frontend_core.validate_log_level(log_level="INVALID") + + def test_validate_integer_valid(self): + """Test validate_integer with valid values.""" + assert frontend_core.validate_integer("42") == 42 + assert frontend_core.validate_integer("0") == 0 + assert frontend_core.validate_integer("-5") == -5 + + def test_validate_integer_with_min_value(self): + """Test validate_integer with min_value constraint.""" + assert frontend_core.validate_integer("5", min_value=1) == 5 + assert frontend_core.validate_integer("1", min_value=1) == 1 + + def test_validate_integer_below_min_value(self): + """Test validate_integer below min_value raises ValueError.""" + with pytest.raises(ValueError, match="must be at least"): + frontend_core.validate_integer("0", min_value=1) + + def test_validate_integer_invalid_string(self): + """Test validate_integer with non-integer string.""" + with pytest.raises(ValueError, match="must be an integer"): + frontend_core.validate_integer("not_a_number") + + def test_validate_integer_custom_name(self): + """Test validate_integer with custom parameter name.""" + with pytest.raises(ValueError, match="max_retries must be an integer"): + frontend_core.validate_integer("invalid", name="max_retries") + + def test_positive_int_valid(self): + """Test positive_int with valid values.""" + assert frontend_core.positive_int("1") == 1 + assert frontend_core.positive_int("100") == 100 + + def test_positive_int_zero(self): + """Test positive_int with zero raises error.""" + import argparse + + with pytest.raises(argparse.ArgumentTypeError): + frontend_core.positive_int("0") + + def test_positive_int_negative(self): + """Test positive_int with negative value raises error.""" + import argparse + + with pytest.raises(argparse.ArgumentTypeError): + frontend_core.positive_int("-1") + + def test_non_negative_int_valid(self): + """Test non_negative_int with valid values.""" + assert frontend_core.non_negative_int("0") == 0 + assert frontend_core.non_negative_int("5") == 5 + + def test_non_negative_int_negative(self): + """Test non_negative_int with negative value raises error.""" + import argparse + + with pytest.raises(argparse.ArgumentTypeError): + frontend_core.non_negative_int("-1") + + def test_validate_database_argparse(self): + """Test validate_database_argparse wrapper.""" + assert 
frontend_core.validate_database_argparse(frontend_core.IN_MEMORY) == frontend_core.IN_MEMORY + + import argparse + + with pytest.raises(argparse.ArgumentTypeError): + frontend_core.validate_database_argparse("InvalidDB") + + def test_validate_log_level_argparse(self): + """Test validate_log_level_argparse wrapper.""" + assert frontend_core.validate_log_level_argparse("DEBUG") == "DEBUG" + + import argparse + + with pytest.raises(argparse.ArgumentTypeError): + frontend_core.validate_log_level_argparse("INVALID") + + +class TestParseMemoryLabels: + """Tests for parse_memory_labels function.""" + + def test_parse_memory_labels_valid(self): + """Test parsing valid JSON labels.""" + json_str = '{"key1": "value1", "key2": "value2"}' + result = frontend_core.parse_memory_labels(json_string=json_str) + + assert result == {"key1": "value1", "key2": "value2"} + + def test_parse_memory_labels_empty(self): + """Test parsing empty JSON object.""" + result = frontend_core.parse_memory_labels(json_string="{}") + assert result == {} + + def test_parse_memory_labels_invalid_json(self): + """Test parsing invalid JSON raises ValueError.""" + with pytest.raises(ValueError, match="Invalid JSON"): + frontend_core.parse_memory_labels(json_string="not valid json") + + def test_parse_memory_labels_not_dict(self): + """Test parsing JSON array raises ValueError.""" + with pytest.raises(ValueError, match="must be a JSON object"): + frontend_core.parse_memory_labels(json_string='["array", "not", "dict"]') + + def test_parse_memory_labels_non_string_key(self): + """Test parsing with non-string values raises ValueError.""" + with pytest.raises(ValueError, match="All label keys and values must be strings"): + frontend_core.parse_memory_labels(json_string='{"key": 123}') + + +class TestResolveInitializationScripts: + """Tests for resolve_initialization_scripts function.""" + + @patch("pyrit.cli.initializer_registry.InitializerRegistry.resolve_script_paths") + def test_resolve_initialization_scripts(self, mock_resolve: MagicMock): + """Test resolve_initialization_scripts calls InitializerRegistry.""" + mock_resolve.return_value = [Path("/test/script.py")] + + result = frontend_core.resolve_initialization_scripts(script_paths=["script.py"]) + + mock_resolve.assert_called_once_with(script_paths=["script.py"]) + assert result == [Path("/test/script.py")] + + +class TestGetDefaultInitializerDiscoveryPath: + """Tests for get_default_initializer_discovery_path function.""" + + def test_get_default_initializer_discovery_path(self): + """Test get_default_initializer_discovery_path returns correct path.""" + path = frontend_core.get_default_initializer_discovery_path() + + assert isinstance(path, Path) + assert path.parts[-3:] == ("setup", "initializers", "scenarios") + + +class TestListFunctions: + """Tests for list_scenarios and list_initializers functions.""" + + def test_list_scenarios(self): + """Test list_scenarios returns scenarios from registry.""" + mock_registry = MagicMock() + mock_registry.list_scenarios.return_value = [{"name": "test_scenario"}] + + context = frontend_core.FrontendCore() + context._scenario_registry = mock_registry + context._initialized = True + + result = frontend_core.list_scenarios(context=context) + + assert result == [{"name": "test_scenario"}] + mock_registry.list_scenarios.assert_called_once() + + def test_list_initializers_without_discovery_path(self): + """Test list_initializers without discovery path.""" + mock_registry = MagicMock() + mock_registry.list_initializers.return_value = [{"name": 
"test_init"}] + + context = frontend_core.FrontendCore() + context._initializer_registry = mock_registry + context._initialized = True + + result = frontend_core.list_initializers(context=context) + + assert result == [{"name": "test_init"}] + mock_registry.list_initializers.assert_called_once() + + @patch("pyrit.cli.initializer_registry.InitializerRegistry") + def test_list_initializers_with_discovery_path(self, mock_init_registry_class: MagicMock): + """Test list_initializers with discovery path.""" + mock_registry = MagicMock() + mock_registry.list_initializers.return_value = [{"name": "custom_init"}] + mock_init_registry_class.return_value = mock_registry + + context = frontend_core.FrontendCore() + discovery_path = Path("/custom/path") + + result = frontend_core.list_initializers(context=context, discovery_path=discovery_path) + + mock_init_registry_class.assert_called_once_with(discovery_path=discovery_path) + assert result == [{"name": "custom_init"}] + + +class TestPrintFunctions: + """Tests for print functions.""" + + def test_print_scenarios_list_with_scenarios(self, capsys): + """Test print_scenarios_list with scenarios.""" + context = frontend_core.FrontendCore() + mock_registry = MagicMock() + mock_registry.list_scenarios.return_value = [ + { + "name": "test_scenario", + "class_name": "TestScenario", + "description": "Test description", + } + ] + context._scenario_registry = mock_registry + context._initialized = True + + result = frontend_core.print_scenarios_list(context=context) + + assert result == 0 + captured = capsys.readouterr() + assert "Available Scenarios" in captured.out + assert "test_scenario" in captured.out + + def test_print_scenarios_list_empty(self, capsys): + """Test print_scenarios_list with no scenarios.""" + context = frontend_core.FrontendCore() + mock_registry = MagicMock() + mock_registry.list_scenarios.return_value = [] + context._scenario_registry = mock_registry + context._initialized = True + + result = frontend_core.print_scenarios_list(context=context) + + assert result == 0 + captured = capsys.readouterr() + assert "No scenarios found" in captured.out + + def test_print_initializers_list_with_initializers(self, capsys): + """Test print_initializers_list with initializers.""" + context = frontend_core.FrontendCore() + mock_registry = MagicMock() + mock_registry.list_initializers.return_value = [ + { + "name": "test_init", + "class_name": "TestInit", + "initializer_name": "test", + "execution_order": 100, + } + ] + context._initializer_registry = mock_registry + context._initialized = True + + result = frontend_core.print_initializers_list(context=context) + + assert result == 0 + captured = capsys.readouterr() + assert "Available Initializers" in captured.out + assert "test_init" in captured.out + + def test_print_initializers_list_empty(self, capsys): + """Test print_initializers_list with no initializers.""" + context = frontend_core.FrontendCore() + mock_registry = MagicMock() + mock_registry.list_initializers.return_value = [] + context._initializer_registry = mock_registry + context._initialized = True + + result = frontend_core.print_initializers_list(context=context) + + assert result == 0 + captured = capsys.readouterr() + assert "No initializers found" in captured.out + + +class TestFormatFunctions: + """Tests for format_scenario_info and format_initializer_info.""" + + def test_format_scenario_info_basic(self, capsys): + """Test format_scenario_info with basic info.""" + scenario_info = { + "name": "test_scenario", + "class_name": 
"TestScenario", + } + + frontend_core.format_scenario_info(scenario_info=scenario_info) + + captured = capsys.readouterr() + assert "test_scenario" in captured.out + assert "TestScenario" in captured.out + + def test_format_scenario_info_with_description(self, capsys): + """Test format_scenario_info with description.""" + scenario_info = { + "name": "test_scenario", + "class_name": "TestScenario", + "description": "This is a test scenario", + } + + frontend_core.format_scenario_info(scenario_info=scenario_info) + + captured = capsys.readouterr() + assert "This is a test scenario" in captured.out + + def test_format_scenario_info_with_strategies(self, capsys): + """Test format_scenario_info with strategies.""" + scenario_info = { + "name": "test_scenario", + "class_name": "TestScenario", + "all_strategies": ["strategy1", "strategy2"], + "default_strategy": "strategy1", + } + + frontend_core.format_scenario_info(scenario_info=scenario_info) + + captured = capsys.readouterr() + assert "strategy1" in captured.out + assert "strategy2" in captured.out + assert "Default Strategy" in captured.out + + def test_format_initializer_info_basic(self, capsys): + """Test format_initializer_info with basic info.""" + from pyrit.cli.initializer_registry import InitializerInfo + + initializer_info: InitializerInfo = { + "name": "test_init", + "class_name": "TestInit", + "initializer_name": "test", + "description": "", + "required_env_vars": [], + "execution_order": 100, + } + + frontend_core.format_initializer_info(initializer_info=initializer_info) + + captured = capsys.readouterr() + assert "test_init" in captured.out + assert "TestInit" in captured.out + assert "100" in captured.out + + def test_format_initializer_info_with_env_vars(self, capsys): + """Test format_initializer_info with environment variables.""" + from pyrit.cli.initializer_registry import InitializerInfo + + initializer_info: InitializerInfo = { + "name": "test_init", + "class_name": "TestInit", + "initializer_name": "test", + "description": "", + "required_env_vars": ["VAR1", "VAR2"], + "execution_order": 100, + } + + frontend_core.format_initializer_info(initializer_info=initializer_info) + + captured = capsys.readouterr() + assert "VAR1" in captured.out + assert "VAR2" in captured.out + + def test_format_initializer_info_with_description(self, capsys): + """Test format_initializer_info with description.""" + from pyrit.cli.initializer_registry import InitializerInfo + + initializer_info: InitializerInfo = { + "name": "test_init", + "class_name": "TestInit", + "initializer_name": "test", + "description": "Test description", + "required_env_vars": [], + "execution_order": 100, + } + + frontend_core.format_initializer_info(initializer_info=initializer_info) + + captured = capsys.readouterr() + assert "Test description" in captured.out + + +class TestParseRunArguments: + """Tests for parse_run_arguments function.""" + + def test_parse_run_arguments_basic(self): + """Test parsing basic scenario name.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario") + + assert result["scenario_name"] == "test_scenario" + assert result["initializers"] is None + assert result["scenario_strategies"] is None + + def test_parse_run_arguments_with_initializers(self): + """Test parsing with initializers.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario --initializers init1 init2") + + assert result["scenario_name"] == "test_scenario" + assert result["initializers"] == ["init1", "init2"] + + def 
test_parse_run_arguments_with_strategies(self): + """Test parsing with strategies.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario --strategies s1 s2") + + assert result["scenario_strategies"] == ["s1", "s2"] + + def test_parse_run_arguments_with_short_strategies(self): + """Test parsing with -s flag.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario -s s1 s2") + + assert result["scenario_strategies"] == ["s1", "s2"] + + def test_parse_run_arguments_with_max_concurrency(self): + """Test parsing with max-concurrency.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency 5") + + assert result["max_concurrency"] == 5 + + def test_parse_run_arguments_with_max_retries(self): + """Test parsing with max-retries.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario --max-retries 3") + + assert result["max_retries"] == 3 + + def test_parse_run_arguments_with_memory_labels(self): + """Test parsing with memory-labels.""" + result = frontend_core.parse_run_arguments(args_string='test_scenario --memory-labels {"key":"value"}') + + assert result["memory_labels"] == {"key": "value"} + + def test_parse_run_arguments_with_database(self): + """Test parsing with database override.""" + result = frontend_core.parse_run_arguments(args_string=f"test_scenario --database {frontend_core.IN_MEMORY}") + + assert result["database"] == frontend_core.IN_MEMORY + + def test_parse_run_arguments_with_log_level(self): + """Test parsing with log-level override.""" + result = frontend_core.parse_run_arguments(args_string="test_scenario --log-level DEBUG") + + assert result["log_level"] == "DEBUG" + + def test_parse_run_arguments_with_initialization_scripts(self): + """Test parsing with initialization-scripts.""" + result = frontend_core.parse_run_arguments( + args_string="test_scenario --initialization-scripts script1.py script2.py" + ) + + assert result["initialization_scripts"] == ["script1.py", "script2.py"] + + def test_parse_run_arguments_complex(self): + """Test parsing complex argument combination.""" + args = "test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10" + result = frontend_core.parse_run_arguments(args_string=args) + + assert result["scenario_name"] == "test_scenario" + assert result["initializers"] == ["init1"] + assert result["scenario_strategies"] == ["s1", "s2"] + assert result["max_concurrency"] == 10 + + def test_parse_run_arguments_empty_raises(self): + """Test parsing empty string raises ValueError.""" + with pytest.raises(ValueError, match="No scenario name provided"): + frontend_core.parse_run_arguments(args_string="") + + def test_parse_run_arguments_invalid_max_concurrency(self): + """Test parsing with invalid max-concurrency.""" + with pytest.raises(ValueError): + frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency 0") + + def test_parse_run_arguments_invalid_max_retries(self): + """Test parsing with invalid max-retries.""" + with pytest.raises(ValueError): + frontend_core.parse_run_arguments(args_string="test_scenario --max-retries -1") + + def test_parse_run_arguments_missing_value(self): + """Test parsing with missing argument value.""" + with pytest.raises(ValueError, match="requires a value"): + frontend_core.parse_run_arguments(args_string="test_scenario --max-concurrency") + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("patch_central_database") +class TestRunScenarioAsync: + """Tests for run_scenario_async function.""" + + 
@patch("pyrit.setup.initialize_pyrit") + @patch("pyrit.scenarios.printer.console_printer.ConsoleScenarioResultPrinter") + async def test_run_scenario_async_basic( + self, + mock_printer_class: MagicMock, + mock_init_pyrit: MagicMock, + ): + """Test running a basic scenario.""" + # Mock context + context = frontend_core.FrontendCore() + mock_scenario_registry = MagicMock() + mock_scenario_class = MagicMock() + mock_scenario_instance = MagicMock() + mock_result = MagicMock() + mock_printer = MagicMock() + mock_printer.print_summary_async = AsyncMock() + + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) + mock_scenario_class.return_value = mock_scenario_instance + mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_printer_class.return_value = mock_printer + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = MagicMock() + context._initialized = True + + # Run scenario + result = await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + ) + + assert result == mock_result + # Verify scenario was instantiated with no arguments (runtime params go to initialize_async) + mock_scenario_class.assert_called_once_with() + mock_scenario_instance.initialize_async.assert_called_once_with() + mock_scenario_instance.run_async.assert_called_once() + mock_printer.print_summary_async.assert_called_once_with(mock_result) + + @patch("pyrit.setup.initialize_pyrit") + async def test_run_scenario_async_not_found(self, mock_init_pyrit: MagicMock): + """Test running non-existent scenario raises ValueError.""" + context = frontend_core.FrontendCore() + mock_scenario_registry = MagicMock() + mock_scenario_registry.get_scenario.return_value = None + mock_scenario_registry.get_scenario_names.return_value = ["other_scenario"] + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = MagicMock() + context._initialized = True + + with pytest.raises(ValueError, match="Scenario 'test_scenario' not found"): + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + ) + + @patch("pyrit.setup.initialize_pyrit") + @patch("pyrit.scenarios.printer.console_printer.ConsoleScenarioResultPrinter") + async def test_run_scenario_async_with_strategies( + self, + mock_printer_class: MagicMock, + mock_init_pyrit: MagicMock, + ): + """Test running scenario with strategies.""" + context = frontend_core.FrontendCore() + mock_scenario_registry = MagicMock() + mock_scenario_class = MagicMock() + mock_scenario_instance = MagicMock() + mock_result = MagicMock() + mock_printer = MagicMock() + mock_printer.print_summary_async = AsyncMock() + + # Mock strategy enum + from enum import Enum + + class MockStrategy(Enum): + strategy1 = "strategy1" + + mock_scenario_class.get_strategy_class.return_value = MockStrategy + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) + mock_scenario_class.return_value = mock_scenario_instance + mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_printer_class.return_value = mock_printer + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = MagicMock() + context._initialized = True + + # Run with strategies + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + scenario_strategies=["strategy1"], + ) + 
+ # Verify scenario was instantiated with no arguments + mock_scenario_class.assert_called_once_with() + # Verify strategy was passed to initialize_async + call_kwargs = mock_scenario_instance.initialize_async.call_args[1] + assert "scenario_strategies" in call_kwargs + + @patch("pyrit.setup.initialize_pyrit") + @patch("pyrit.scenarios.printer.console_printer.ConsoleScenarioResultPrinter") + async def test_run_scenario_async_with_initializers( + self, + mock_printer_class: MagicMock, + mock_init_pyrit: MagicMock, + ): + """Test running scenario with initializers.""" + context = frontend_core.FrontendCore(initializer_names=["test_init"]) + mock_scenario_registry = MagicMock() + mock_initializer_registry = MagicMock() + mock_scenario_class = MagicMock() + mock_scenario_instance = MagicMock() + mock_result = MagicMock() + mock_printer = MagicMock() + mock_printer.print_summary_async = AsyncMock() + + mock_initializer_class = MagicMock() + mock_initializer_registry.get_initializer_class.return_value = mock_initializer_class + + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) + mock_scenario_class.return_value = mock_scenario_instance + mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_printer_class.return_value = mock_printer + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = mock_initializer_registry + context._initialized = True + + # Run with initializers + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + ) + + # Verify initializer was retrieved + mock_initializer_registry.get_initializer_class.assert_called_once_with(name="test_init") + + @patch("pyrit.setup.initialize_pyrit") + @patch("pyrit.scenarios.printer.console_printer.ConsoleScenarioResultPrinter") + async def test_run_scenario_async_with_max_concurrency( + self, + mock_printer_class: MagicMock, + mock_init_pyrit: MagicMock, + ): + """Test running scenario with max_concurrency.""" + context = frontend_core.FrontendCore() + mock_scenario_registry = MagicMock() + mock_scenario_class = MagicMock() + mock_scenario_instance = MagicMock() + mock_result = MagicMock() + mock_printer = MagicMock() + mock_printer.print_summary_async = AsyncMock() + + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) + mock_scenario_class.return_value = mock_scenario_instance + mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_printer_class.return_value = mock_printer + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = MagicMock() + context._initialized = True + + # Run with max_concurrency + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + max_concurrency=5, + ) + + # Verify scenario was instantiated with no arguments + mock_scenario_class.assert_called_once_with() + # Verify max_concurrency was passed to initialize_async + call_kwargs = mock_scenario_instance.initialize_async.call_args[1] + assert call_kwargs["max_concurrency"] == 5 + + @patch("pyrit.setup.initialize_pyrit") + @patch("pyrit.scenarios.printer.console_printer.ConsoleScenarioResultPrinter") + async def test_run_scenario_async_without_print_summary( + self, + mock_printer_class: MagicMock, + mock_init_pyrit: MagicMock, + ): + """Test running scenario without printing summary.""" + context = frontend_core.FrontendCore() + 
mock_scenario_registry = MagicMock() + mock_scenario_class = MagicMock() + mock_scenario_instance = MagicMock() + mock_result = MagicMock() + mock_printer = MagicMock() + + mock_scenario_instance.initialize_async = AsyncMock() + mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) + mock_scenario_class.return_value = mock_scenario_instance + mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_printer_class.return_value = mock_printer + + context._scenario_registry = mock_scenario_registry + context._initializer_registry = MagicMock() + context._initialized = True + + # Run without printing + await frontend_core.run_scenario_async( + scenario_name="test_scenario", + context=context, + print_summary=False, + ) + + # Verify printer was not called + assert mock_printer.print_summary_async.call_count == 0 + + +class TestArgHelp: + """Tests for frontend_core.ARG_HELP dictionary.""" + + def test_arg_help_contains_all_keys(self): + """Test frontend_core.ARG_HELP contains expected keys.""" + expected_keys = [ + "initializers", + "initialization_scripts", + "scenario_strategies", + "max_concurrency", + "max_retries", + "memory_labels", + "database", + "log_level", + ] + + for key in expected_keys: + assert key in frontend_core.ARG_HELP + assert isinstance(frontend_core.ARG_HELP[key], str) + assert len(frontend_core.ARG_HELP[key]) > 0 diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py new file mode 100644 index 000000000..7e7d3bb27 --- /dev/null +++ b/tests/unit/cli/test_pyrit_scan.py @@ -0,0 +1,421 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Unit tests for the pyrit_scan CLI module. +""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from pyrit.cli import pyrit_scan + + +class TestParseArgs: + """Tests for parse_args function.""" + + def test_parse_args_list_scenarios(self): + """Test parsing --list-scenarios flag.""" + args = pyrit_scan.parse_args(["--list-scenarios"]) + + assert args.list_scenarios is True + assert args.scenario_name is None + + def test_parse_args_list_initializers(self): + """Test parsing --list-initializers flag.""" + args = pyrit_scan.parse_args(["--list-initializers"]) + + assert args.list_initializers is True + assert args.scenario_name is None + + def test_parse_args_scenario_name_only(self): + """Test parsing scenario name without options.""" + args = pyrit_scan.parse_args(["test_scenario"]) + + assert args.scenario_name == "test_scenario" + assert args.database == "SQLite" + assert args.log_level == "WARNING" + + def test_parse_args_with_database(self): + """Test parsing with database option.""" + args = pyrit_scan.parse_args(["test_scenario", "--database", "InMemory"]) + + assert args.database == "InMemory" + + def test_parse_args_with_log_level(self): + """Test parsing with log-level option.""" + args = pyrit_scan.parse_args(["test_scenario", "--log-level", "DEBUG"]) + + assert args.log_level == "DEBUG" + + def test_parse_args_with_initializers(self): + """Test parsing with initializers.""" + args = pyrit_scan.parse_args(["test_scenario", "--initializers", "init1", "init2"]) + + assert args.initializers == ["init1", "init2"] + + def test_parse_args_with_initialization_scripts(self): + """Test parsing with initialization-scripts.""" + args = pyrit_scan.parse_args(["test_scenario", "--initialization-scripts", "script1.py", "script2.py"]) + + assert args.initialization_scripts == ["script1.py", "script2.py"] + + 
def test_parse_args_with_strategies(self): + """Test parsing with strategies.""" + args = pyrit_scan.parse_args(["test_scenario", "--strategies", "s1", "s2"]) + + assert args.scenario_strategies == ["s1", "s2"] + + def test_parse_args_with_strategies_short_flag(self): + """Test parsing with -s flag.""" + args = pyrit_scan.parse_args(["test_scenario", "-s", "s1", "s2"]) + + assert args.scenario_strategies == ["s1", "s2"] + + def test_parse_args_with_max_concurrency(self): + """Test parsing with max-concurrency.""" + args = pyrit_scan.parse_args(["test_scenario", "--max-concurrency", "5"]) + + assert args.max_concurrency == 5 + + def test_parse_args_with_max_retries(self): + """Test parsing with max-retries.""" + args = pyrit_scan.parse_args(["test_scenario", "--max-retries", "3"]) + + assert args.max_retries == 3 + + def test_parse_args_with_memory_labels(self): + """Test parsing with memory-labels.""" + args = pyrit_scan.parse_args(["test_scenario", "--memory-labels", '{"key":"value"}']) + + assert args.memory_labels == '{"key":"value"}' + + def test_parse_args_complex_command(self): + """Test parsing complex command with multiple options.""" + args = pyrit_scan.parse_args( + [ + "encoding_scenario", + "--database", + "InMemory", + "--log-level", + "INFO", + "--initializers", + "openai_target", + "--strategies", + "base64", + "rot13", + "--max-concurrency", + "10", + "--max-retries", + "5", + "--memory-labels", + '{"env":"test"}', + ] + ) + + assert args.scenario_name == "encoding_scenario" + assert args.database == "InMemory" + assert args.log_level == "INFO" + assert args.initializers == ["openai_target"] + assert args.scenario_strategies == ["base64", "rot13"] + assert args.max_concurrency == 10 + assert args.max_retries == 5 + assert args.memory_labels == '{"env":"test"}' + + def test_parse_args_invalid_database(self): + """Test parsing with invalid database raises error.""" + with pytest.raises(SystemExit): + pyrit_scan.parse_args(["test_scenario", "--database", "InvalidDB"]) + + def test_parse_args_invalid_log_level(self): + """Test parsing with invalid log level raises error.""" + with pytest.raises(SystemExit): + pyrit_scan.parse_args(["test_scenario", "--log-level", "INVALID"]) + + def test_parse_args_invalid_max_concurrency(self): + """Test parsing with invalid max-concurrency raises error.""" + with pytest.raises(SystemExit): + pyrit_scan.parse_args(["test_scenario", "--max-concurrency", "0"]) + + def test_parse_args_invalid_max_retries(self): + """Test parsing with invalid max-retries raises error.""" + with pytest.raises(SystemExit): + pyrit_scan.parse_args(["test_scenario", "--max-retries", "-1"]) + + def test_parse_args_help_flag(self): + """Test parsing --help flag exits.""" + with pytest.raises(SystemExit) as exc_info: + pyrit_scan.parse_args(["--help"]) + + assert exc_info.value.code == 0 + + +class TestMain: + """Tests for main function.""" + + @patch("pyrit.cli.frontend_core.print_scenarios_list") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_scenarios(self, mock_frontend_core: MagicMock, mock_print_scenarios: MagicMock): + """Test main with --list-scenarios flag.""" + mock_print_scenarios.return_value = 0 + + result = pyrit_scan.main(["--list-scenarios"]) + + assert result == 0 + mock_print_scenarios.assert_called_once() + mock_frontend_core.assert_called_once() + + @patch("pyrit.cli.frontend_core.print_initializers_list") + @patch("pyrit.cli.frontend_core.FrontendCore") + @patch("pyrit.cli.frontend_core.get_default_initializer_discovery_path") 
+ def test_main_list_initializers( + self, + mock_get_path: MagicMock, + mock_frontend_core: MagicMock, + mock_print_initializers: MagicMock, + ): + """Test main with --list-initializers flag.""" + mock_print_initializers.return_value = 0 + mock_get_path.return_value = Path("/test/path") + + result = pyrit_scan.main(["--list-initializers"]) + + assert result == 0 + mock_print_initializers.assert_called_once() + mock_get_path.assert_called_once() + + @patch("pyrit.cli.frontend_core.print_scenarios_list") + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_list_scenarios_with_scripts( + self, + mock_frontend_core: MagicMock, + mock_resolve_scripts: MagicMock, + mock_print_scenarios: MagicMock, + ): + """Test main with --list-scenarios and --initialization-scripts.""" + mock_resolve_scripts.return_value = [Path("/test/script.py")] + mock_print_scenarios.return_value = 0 + + result = pyrit_scan.main(["--list-scenarios", "--initialization-scripts", "script.py"]) + + assert result == 0 + mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) + mock_print_scenarios.assert_called_once() + + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_main_list_scenarios_with_missing_script(self, mock_resolve_scripts: MagicMock): + """Test main with --list-scenarios and missing script file.""" + mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + + result = pyrit_scan.main(["--list-scenarios", "--initialization-scripts", "missing.py"]) + + assert result == 1 + + def test_main_no_scenario_specified(self, capsys): + """Test main without scenario name.""" + result = pyrit_scan.main([]) + + assert result == 1 + captured = capsys.readouterr() + assert "No scenario specified" in captured.out + + @patch("pyrit.cli.pyrit_scan.asyncio.run") + @patch("pyrit.cli.frontend_core.run_scenario_async") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_run_scenario_basic( + self, + mock_frontend_core: MagicMock, + mock_run_scenario: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test main running a basic scenario.""" + result = pyrit_scan.main(["test_scenario", "--initializers", "test_init"]) + + assert result == 0 + mock_asyncio_run.assert_called_once() + + @patch("pyrit.cli.pyrit_scan.asyncio.run") + @patch("pyrit.cli.frontend_core.run_scenario_async") + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_run_scenario_with_scripts( + self, + mock_frontend_core: MagicMock, + mock_resolve_scripts: MagicMock, + mock_run_scenario: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test main running scenario with initialization scripts.""" + mock_resolve_scripts.return_value = [Path("/test/script.py")] + + result = pyrit_scan.main(["test_scenario", "--initialization-scripts", "script.py"]) + + assert result == 0 + mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) + mock_asyncio_run.assert_called_once() + + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_main_run_scenario_with_missing_script(self, mock_resolve_scripts: MagicMock): + """Test main with missing initialization script.""" + mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + + result = pyrit_scan.main(["test_scenario", "--initialization-scripts", "missing.py"]) + + assert result == 1 + + @patch("pyrit.cli.pyrit_scan.asyncio.run") + 
@patch("pyrit.cli.frontend_core.run_scenario_async") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_run_scenario_with_all_options( + self, + mock_frontend_core: MagicMock, + mock_run_scenario: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test main with all scenario options.""" + result = pyrit_scan.main( + [ + "test_scenario", + "--database", + "InMemory", + "--log-level", + "DEBUG", + "--initializers", + "init1", + "init2", + "--strategies", + "s1", + "s2", + "--max-concurrency", + "10", + "--max-retries", + "5", + "--memory-labels", + '{"key":"value"}', + ] + ) + + assert result == 0 + mock_asyncio_run.assert_called_once() + + # Verify FrontendCore was called with correct args + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["database"] == "InMemory" + assert call_kwargs["log_level"] == "DEBUG" + assert call_kwargs["initializer_names"] == ["init1", "init2"] + + @patch("pyrit.cli.pyrit_scan.asyncio.run") + @patch("pyrit.cli.frontend_core.parse_memory_labels") + @patch("pyrit.cli.frontend_core.run_scenario_async") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_run_scenario_with_memory_labels( + self, + mock_frontend_core: MagicMock, + mock_run_scenario: MagicMock, + mock_parse_labels: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test main with memory labels parsing.""" + mock_parse_labels.return_value = {"key": "value"} + + result = pyrit_scan.main(["test_scenario", "--initializers", "test_init", "--memory-labels", '{"key":"value"}']) + + assert result == 0 + mock_parse_labels.assert_called_once_with(json_string='{"key":"value"}') + + @patch("pyrit.cli.pyrit_scan.asyncio.run") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_run_scenario_with_exception( + self, + mock_frontend_core: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test main handles exceptions during scenario run.""" + mock_asyncio_run.side_effect = ValueError("Test error") + + result = pyrit_scan.main(["test_scenario", "--initializers", "test_init"]) + + assert result == 1 + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_database_defaults_to_sqlite(self, mock_frontend_core: MagicMock): + """Test main uses SQLite as default database.""" + pyrit_scan.main(["--list-scenarios"]) + + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["database"] == "SQLite" + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_log_level_defaults_to_warning(self, mock_frontend_core: MagicMock): + """Test main uses WARNING as default log level.""" + pyrit_scan.main(["--list-scenarios"]) + + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["log_level"] == "WARNING" + + def test_main_with_invalid_args(self): + """Test main with invalid arguments.""" + result = pyrit_scan.main(["--invalid-flag"]) + + assert result == 2 # argparse returns 2 for invalid arguments + + @patch("builtins.print") + def test_main_prints_startup_message(self, mock_print: MagicMock): + """Test main prints startup message.""" + pyrit_scan.main(["--list-scenarios"]) + + # Check that "Starting PyRIT..." 
was printed
+        calls = [str(call_obj) for call_obj in mock_print.call_args_list]
+        assert any("Starting PyRIT" in call_str for call_str in calls)
+
+    @patch("pyrit.cli.pyrit_scan.asyncio.run")
+    @patch("pyrit.cli.frontend_core.FrontendCore")
+    def test_main_run_scenario_calls_run_scenario_async(
+        self,
+        mock_frontend_core: MagicMock,
+        mock_asyncio_run: MagicMock,
+    ):
+        """Test main properly calls run_scenario_async."""
+        pyrit_scan.main(["test_scenario", "--initializers", "test_init", "--strategies", "s1"])
+
+        # Verify asyncio.run was called with run_scenario_async
+        assert mock_asyncio_run.call_count == 1
+
+
+class TestMainIntegration:
+    """Integration-style tests for main function."""
+
+    @patch("pyrit.cli.frontend_core.print_scenarios_list")
+    @patch("pyrit.cli.scenario_registry.ScenarioRegistry")
+    @patch("pyrit.setup.initialize_pyrit")
+    def test_main_list_scenarios_integration(
+        self,
+        mock_init_pyrit: MagicMock,
+        mock_scenario_registry: MagicMock,
+        mock_print_scenarios: MagicMock,
+    ):
+        """Test main --list-scenarios with minimal mocking."""
+        mock_print_scenarios.return_value = 0
+
+        result = pyrit_scan.main(["--list-scenarios"])
+
+        assert result == 0
+
+    @patch("pyrit.cli.frontend_core.print_initializers_list")
+    @patch("pyrit.cli.frontend_core.get_default_initializer_discovery_path")
+    def test_main_list_initializers_integration(
+        self,
+        mock_get_path: MagicMock,
+        mock_print_initializers: MagicMock,
+    ):
+        """Test main --list-initializers with minimal mocking."""
+        mock_get_path.return_value = Path("/test/path")
+        mock_print_initializers.return_value = 0
+
+        result = pyrit_scan.main(["--list-initializers"])
+
+        assert result == 0
diff --git a/tests/unit/cli/test_pyrit_shell.py b/tests/unit/cli/test_pyrit_shell.py
new file mode 100644
index 000000000..05ac8a46d
--- /dev/null
+++ b/tests/unit/cli/test_pyrit_shell.py
@@ -0,0 +1,737 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+Unit tests for the pyrit_shell CLI module.
+""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +from pyrit.cli import pyrit_shell + + +class TestPyRITShell: + """Tests for PyRITShell class.""" + + def test_init(self): + """Test PyRITShell initialization.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + assert shell.context == mock_context + assert shell.default_database == "SQLite" + assert shell.default_log_level == "WARNING" + assert shell._scenario_history == [] + # Initialize is called in a background thread, so we need to wait for it + shell._init_complete.wait(timeout=2) + mock_context.initialize.assert_called_once() + + def test_prompt_and_intro(self): + """Test shell prompt and intro are set.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + assert shell.prompt == "pyrit> " + assert shell.intro is not None + assert "Interactive Shell" in str(shell.intro) + + @patch("pyrit.cli.frontend_core.print_scenarios_list") + def test_do_list_scenarios(self, mock_print_scenarios: MagicMock): + """Test do_list_scenarios command.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_list_scenarios("") + + mock_print_scenarios.assert_called_once_with(context=mock_context) + + @patch("pyrit.cli.frontend_core.print_scenarios_list") + def test_do_list_scenarios_with_exception(self, mock_print_scenarios: MagicMock, capsys): + """Test do_list_scenarios handles exceptions.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + mock_print_scenarios.side_effect = ValueError("Test error") + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_list_scenarios("") + + captured = capsys.readouterr() + assert "Error listing scenarios" in captured.out + + @patch("pyrit.cli.frontend_core.print_initializers_list") + def test_do_list_initializers(self, mock_print_initializers: MagicMock): + """Test do_list_initializers command.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_list_initializers("") + + mock_print_initializers.assert_called_once_with(context=mock_context, discovery_path=None) + + @patch("pyrit.cli.frontend_core.print_initializers_list") + def test_do_list_initializers_with_path(self, mock_print_initializers: MagicMock): + """Test do_list_initializers with custom path.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_list_initializers("/custom/path") + + assert mock_print_initializers.call_count == 1 + call_kwargs = mock_print_initializers.call_args[1] + assert isinstance(call_kwargs["discovery_path"], Path) + + @patch("pyrit.cli.frontend_core.print_initializers_list") + def test_do_list_initializers_with_exception(self, mock_print_initializers: MagicMock, capsys): + """Test do_list_initializers handles exceptions.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + mock_print_initializers.side_effect = ValueError("Test error") + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_list_initializers("") + + captured = capsys.readouterr() + assert "Error listing initializers" in captured.out + + def test_do_run_empty_line(self, capsys): + """Test 
do_run with empty line.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_run("") + + captured = capsys.readouterr() + assert "Specify a scenario name" in captured.out + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.cli.frontend_core.run_scenario_async") + @patch("pyrit.cli.frontend_core.parse_run_arguments") + def test_do_run_basic_scenario( + self, + mock_parse_args: MagicMock, + mock_run_scenario: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test do_run with basic scenario.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + mock_context.initialize = MagicMock() + + mock_parse_args.return_value = { + "scenario_name": "test_scenario", + "initializers": ["test_init"], + "initialization_scripts": None, + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": None, + "log_level": None, + } + + mock_result = MagicMock() + mock_asyncio_run.return_value = mock_result + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_run("test_scenario --initializers test_init") + + mock_parse_args.assert_called_once() + mock_asyncio_run.assert_called_once() + + # Verify result was stored in history + assert len(shell._scenario_history) == 1 + assert shell._scenario_history[0][0] == "test_scenario --initializers test_init" + assert shell._scenario_history[0][1] == mock_result + + @patch("pyrit.cli.frontend_core.parse_run_arguments") + def test_do_run_parse_error(self, mock_parse_args: MagicMock, capsys): + """Test do_run with parse error.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + mock_parse_args.side_effect = ValueError("Parse error") + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_run("test_scenario --invalid") + + captured = capsys.readouterr() + assert "Error: Parse error" in captured.out + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.cli.frontend_core.run_scenario_async") + @patch("pyrit.cli.frontend_core.parse_run_arguments") + @patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_do_run_with_initialization_scripts( + self, + mock_resolve_scripts: MagicMock, + mock_parse_args: MagicMock, + mock_run_scenario: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test do_run with initialization scripts.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + mock_context.initialize = MagicMock() + + mock_parse_args.return_value = { + "scenario_name": "test_scenario", + "initializers": None, + "initialization_scripts": ["script.py"], + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": None, + "log_level": None, + } + + mock_resolve_scripts.return_value = [Path("/test/script.py")] + mock_asyncio_run.return_value = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_run("test_scenario --initialization-scripts script.py") + + mock_resolve_scripts.assert_called_once_with(script_paths=["script.py"]) + mock_asyncio_run.assert_called_once() + + @patch("pyrit.cli.frontend_core.parse_run_arguments") + 
@patch("pyrit.cli.frontend_core.resolve_initialization_scripts") + def test_do_run_with_missing_script( + self, + mock_resolve_scripts: MagicMock, + mock_parse_args: MagicMock, + capsys, + ): + """Test do_run with missing initialization script.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + mock_parse_args.return_value = { + "scenario_name": "test_scenario", + "initializers": None, + "initialization_scripts": ["missing.py"], + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": None, + "log_level": None, + } + + mock_resolve_scripts.side_effect = FileNotFoundError("Script not found") + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_run("test_scenario --initialization-scripts missing.py") + + captured = capsys.readouterr() + assert "Error: Script not found" in captured.out + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.cli.frontend_core.parse_run_arguments") + def test_do_run_with_database_override( + self, + mock_parse_args: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test do_run with database override.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + mock_context.initialize = MagicMock() + + mock_parse_args.return_value = { + "scenario_name": "test_scenario", + "initializers": ["test_init"], + "initialization_scripts": None, + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": "InMemory", + "log_level": None, + } + + mock_asyncio_run.return_value = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + with patch("pyrit.cli.frontend_core.FrontendCore") as mock_frontend: + shell.do_run("test_scenario --initializers test_init --database InMemory") + + # Verify FrontendCore was created with overridden database + call_kwargs = mock_frontend.call_args[1] + assert call_kwargs["database"] == "InMemory" + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.cli.frontend_core.parse_run_arguments") + def test_do_run_with_exception( + self, + mock_parse_args: MagicMock, + mock_asyncio_run: MagicMock, + capsys, + ): + """Test do_run handles exceptions during scenario run.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + mock_context.initialize = MagicMock() + + mock_parse_args.return_value = { + "scenario_name": "test_scenario", + "initializers": ["test_init"], + "initialization_scripts": None, + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": None, + "log_level": None, + } + + mock_asyncio_run.side_effect = ValueError("Test error") + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_run("test_scenario --initializers test_init") + + captured = capsys.readouterr() + assert "Error: Test error" in captured.out + + def test_do_scenario_history_empty(self, capsys): + """Test do_scenario_history with no history.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_scenario_history("") + + captured = capsys.readouterr() + assert "No scenario runs in history" in captured.out + + def 
test_do_scenario_history_with_runs(self, capsys): + """Test do_scenario_history with scenario runs.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell._scenario_history = [ + ("test_scenario1 --initializers init1", MagicMock()), + ("test_scenario2 --initializers init2", MagicMock()), + ] + + shell.do_scenario_history("") + + captured = capsys.readouterr() + assert "Scenario Run History" in captured.out + assert "test_scenario1" in captured.out + assert "test_scenario2" in captured.out + assert "Total runs: 2" in captured.out + + def test_do_print_scenario_empty(self, capsys): + """Test do_print_scenario with no history.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.do_print_scenario("") + + captured = capsys.readouterr() + assert "No scenario runs in history" in captured.out + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.scenarios.printer.console_printer.ConsoleScenarioResultPrinter") + def test_do_print_scenario_all( + self, + mock_printer_class: MagicMock, + mock_asyncio_run: MagicMock, + capsys, + ): + """Test do_print_scenario without argument prints all.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + mock_printer = MagicMock() + mock_printer_class.return_value = mock_printer + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell._scenario_history = [ + ("test_scenario1", MagicMock()), + ("test_scenario2", MagicMock()), + ] + + shell.do_print_scenario("") + + captured = capsys.readouterr() + assert "Printing all scenario results" in captured.out + assert mock_asyncio_run.call_count == 2 + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.scenarios.printer.console_printer.ConsoleScenarioResultPrinter") + def test_do_print_scenario_specific( + self, + mock_printer_class: MagicMock, + mock_asyncio_run: MagicMock, + capsys, + ): + """Test do_print_scenario with specific scenario number.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + mock_printer = MagicMock() + mock_printer_class.return_value = mock_printer + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell._scenario_history = [ + ("test_scenario1", MagicMock()), + ("test_scenario2", MagicMock()), + ] + + shell.do_print_scenario("1") + + captured = capsys.readouterr() + assert "Scenario Run #1" in captured.out + assert mock_asyncio_run.call_count == 1 + + def test_do_print_scenario_invalid_number(self, capsys): + """Test do_print_scenario with invalid scenario number.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell._scenario_history = [ + ("test_scenario1", MagicMock()), + ] + + shell.do_print_scenario("5") + + captured = capsys.readouterr() + assert "must be between 1 and 1" in captured.out + + def test_do_print_scenario_non_integer(self, capsys): + """Test do_print_scenario with non-integer argument.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell._scenario_history = [ + ("test_scenario1", MagicMock()), + ] + + shell.do_print_scenario("invalid") + + captured = capsys.readouterr() + assert "Invalid scenario number" in captured.out + + def test_do_help_without_arg(self, capsys): + """Test do_help without argument.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + 
+ shell = pyrit_shell.PyRITShell(context=mock_context) + + # Capture help output + with patch("cmd.Cmd.do_help"): + shell.do_help("") + captured = capsys.readouterr() + assert "Shell Startup Options" in captured.out + + def test_do_help_with_arg(self): + """Test do_help with specific command.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + with patch("cmd.Cmd.do_help") as mock_parent_help: + shell.do_help("run") + mock_parent_help.assert_called_with("run") + + def test_do_exit(self, capsys): + """Test do_exit command.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + result = shell.do_exit("") + + assert result is True + captured = capsys.readouterr() + assert "Goodbye" in captured.out + + def test_do_quit_alias(self): + """Test do_quit is alias for do_exit.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + assert shell.do_quit == shell.do_exit + + def test_do_q_alias(self): + """Test do_q is alias for do_exit.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + assert shell.do_q == shell.do_exit + + def test_do_eof_alias(self): + """Test do_EOF is alias for do_exit.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + assert shell.do_EOF == shell.do_exit + + @patch("os.system") + def test_do_clear_windows(self, mock_system: MagicMock): + """Test do_clear on Windows.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + with patch("os.name", "nt"): + shell.do_clear("") + mock_system.assert_called_with("cls") + + @patch("os.system") + def test_do_clear_unix(self, mock_system: MagicMock): + """Test do_clear on Unix.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + with patch("os.name", "posix"): + shell.do_clear("") + mock_system.assert_called_with("clear") + + def test_emptyline(self): + """Test emptyline doesn't repeat last command.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + result = shell.emptyline() + + assert result is False + + def test_default_with_hyphen_to_underscore(self): + """Test default converts hyphens to underscores.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + # Mock a method with underscores + shell.do_list_scenarios = MagicMock() + + shell.default("list-scenarios") + + shell.do_list_scenarios.assert_called_once_with("") + + def test_default_unknown_command(self, capsys): + """Test default with unknown command.""" + mock_context = MagicMock() + mock_context.initialize = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + shell.default("unknown_command") + + captured = capsys.readouterr() + assert "Unknown command" in captured.out + + +class TestMain: + """Tests for main function.""" + + @patch("pyrit.cli.pyrit_shell.PyRITShell") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_default_args(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + """Test main with default arguments.""" + 
mock_shell = MagicMock() + mock_shell_class.return_value = mock_shell + + with patch("sys.argv", ["pyrit_shell"]): + result = pyrit_shell.main() + + assert result == 0 + mock_frontend_core.assert_called_once() + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["database"] == "SQLite" + assert call_kwargs["log_level"] == "WARNING" + mock_shell.cmdloop.assert_called_once() + + @patch("pyrit.cli.pyrit_shell.PyRITShell") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_with_database_arg(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + """Test main with database argument.""" + mock_shell = MagicMock() + mock_shell_class.return_value = mock_shell + + with patch("sys.argv", ["pyrit_shell", "--database", "InMemory"]): + result = pyrit_shell.main() + + assert result == 0 + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["database"] == "InMemory" + + @patch("pyrit.cli.pyrit_shell.PyRITShell") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_with_log_level_arg(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock): + """Test main with log-level argument.""" + mock_shell = MagicMock() + mock_shell_class.return_value = mock_shell + + with patch("sys.argv", ["pyrit_shell", "--log-level", "DEBUG"]): + result = pyrit_shell.main() + + assert result == 0 + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["log_level"] == "DEBUG" + + @patch("pyrit.cli.pyrit_shell.PyRITShell") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_with_keyboard_interrupt(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock, capsys): + """Test main handles keyboard interrupt.""" + mock_shell = MagicMock() + mock_shell.cmdloop.side_effect = KeyboardInterrupt() + mock_shell_class.return_value = mock_shell + + with patch("sys.argv", ["pyrit_shell"]): + result = pyrit_shell.main() + + assert result == 0 + captured = capsys.readouterr() + assert "Interrupted" in captured.out + + @patch("pyrit.cli.pyrit_shell.PyRITShell") + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_with_exception(self, mock_frontend_core: MagicMock, mock_shell_class: MagicMock, capsys): + """Test main handles exceptions.""" + mock_shell = MagicMock() + mock_shell.cmdloop.side_effect = ValueError("Test error") + mock_shell_class.return_value = mock_shell + + with patch("sys.argv", ["pyrit_shell"]): + result = pyrit_shell.main() + + assert result == 1 + captured = capsys.readouterr() + assert "Error:" in captured.out + + @patch("pyrit.cli.frontend_core.FrontendCore") + def test_main_creates_context_without_initializers(self, mock_frontend_core: MagicMock): + """Test main creates context without initializers.""" + with patch("pyrit.cli.pyrit_shell.PyRITShell"): + with patch("sys.argv", ["pyrit_shell"]): + pyrit_shell.main() + + call_kwargs = mock_frontend_core.call_args[1] + assert call_kwargs["initialization_scripts"] is None + assert call_kwargs["initializer_names"] is None + + +class TestPyRITShellRunCommand: + """Detailed tests for the run command.""" + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.cli.frontend_core.parse_run_arguments") + def test_run_with_all_parameters( + self, + mock_parse_args: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test run command with all parameters.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + 
mock_context.initialize = MagicMock() + + mock_parse_args.return_value = { + "scenario_name": "test_scenario", + "initializers": ["init1"], + "initialization_scripts": None, + "scenario_strategies": ["s1", "s2"], + "max_concurrency": 10, + "max_retries": 5, + "memory_labels": {"key": "value"}, + "database": "InMemory", + "log_level": "DEBUG", + } + + mock_asyncio_run.return_value = MagicMock() + + shell = pyrit_shell.PyRITShell(context=mock_context) + + with patch("pyrit.cli.frontend_core.FrontendCore"): + with patch("pyrit.cli.frontend_core.run_scenario_async"): + shell.do_run("test_scenario --initializers init1 --strategies s1 s2 --max-concurrency 10") + + # Verify run_scenario_async was called with correct args + # (it's called via asyncio.run, so check the mock_asyncio_run call) + assert mock_asyncio_run.call_count == 1 + + @patch("pyrit.cli.pyrit_shell.asyncio.run") + @patch("pyrit.cli.frontend_core.parse_run_arguments") + def test_run_stores_result_in_history( + self, + mock_parse_args: MagicMock, + mock_asyncio_run: MagicMock, + ): + """Test run command stores result in history.""" + mock_context = MagicMock() + mock_context._database = "SQLite" + mock_context._log_level = "WARNING" + mock_context._scenario_registry = MagicMock() + mock_context._initializer_registry = MagicMock() + mock_context.initialize = MagicMock() + + mock_parse_args.return_value = { + "scenario_name": "test_scenario", + "initializers": ["test_init"], + "initialization_scripts": None, + "scenario_strategies": None, + "max_concurrency": None, + "max_retries": None, + "memory_labels": None, + "database": None, + "log_level": None, + } + + mock_result1 = MagicMock() + mock_result2 = MagicMock() + mock_asyncio_run.side_effect = [mock_result1, mock_result2] + + shell = pyrit_shell.PyRITShell(context=mock_context) + + # Run two scenarios + shell.do_run("scenario1 --initializers init1") + shell.do_run("scenario2 --initializers init2") + + # Verify both are in history + assert len(shell._scenario_history) == 2 + assert shell._scenario_history[0][1] == mock_result1 + assert shell._scenario_history[1][1] == mock_result2 diff --git a/tests/unit/cli/test_scenario_registry.py b/tests/unit/cli/test_scenario_registry.py index e86e944d8..1b938fd00 100644 --- a/tests/unit/cli/test_scenario_registry.py +++ b/tests/unit/cli/test_scenario_registry.py @@ -78,6 +78,7 @@ def test_get_scenario_names_empty(self): """Test get_scenario_names with no scenarios.""" registry = ScenarioRegistry() registry._scenarios = {} + registry._discovered = True # Prevent auto-discovery names = registry.get_scenario_names() assert names == [] @@ -89,6 +90,7 @@ def test_get_scenario_names_sorted(self): "apple_scenario": MockScenario, "middle_scenario": MockScenario, } + registry._discovered = True # Prevent auto-discovery names = registry.get_scenario_names() assert names == ["apple_scenario", "middle_scenario", "zebra_scenario"] @@ -114,6 +116,7 @@ def get_default_strategy(cls) -> ScenarioStrategy: registry._scenarios = { "test_scenario": DocumentedScenario, } + registry._discovered = True # Prevent auto-discovery scenarios = registry.list_scenarios() @@ -142,6 +145,7 @@ def get_default_strategy(cls) -> ScenarioStrategy: registry = ScenarioRegistry() registry._scenarios = {"undocumented": UndocumentedScenario} + registry._discovered = True # Prevent auto-discovery scenarios = registry.list_scenarios()