diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 786baa6b0..dbec6f246 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,6 +49,12 @@ repos: name: Import Sort (Jupyter Notebooks) args: [--profile=black] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.4 + hooks: + - id: ruff-check + name: ruff-check + - repo: https://github.com/PyCQA/flake8 rev: 7.1.2 hooks: @@ -67,13 +73,6 @@ repos: additional_dependencies: ['requests'] exclude: (release_process.md|git.md|^doc/deployment/|tests|pyrit/prompt_converter/morse_converter.py|.github|pyrit/prompt_converter/emoji_converter.py|pyrit/score/markdown_injection.py|^pyrit/datasets/|^pyrit/auxiliary_attacks/gcg/) - - repo: https://github.com/pycqa/pylint - rev: v3.3.7 - hooks: - - id: pylint - args: [--disable=all, --enable=unused-import] - exclude: NOTICE.txt - - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.15.0 hooks: diff --git a/doc/code/scoring/7_batch_scorer.ipynb b/doc/code/scoring/7_batch_scorer.ipynb index f23b3949d..2df36f06e 100644 --- a/doc/code/scoring/7_batch_scorer.ipynb +++ b/doc/code/scoring/7_batch_scorer.ipynb @@ -3,9 +3,7 @@ { "cell_type": "markdown", "id": "0", - "metadata": { - "lines_to_next_cell": 0 - }, + "metadata": {}, "source": [ "# 7. Batch Scoring\n", "\n", @@ -133,10 +131,9 @@ } ], "source": [ - "# pylint: disable=W0611\n", "from pyrit.memory import CentralMemory\n", "from pyrit.prompt_target import OpenAIChatTarget\n", - "from pyrit.score import (\n", + "from pyrit.score import ( # noqa: F401\n", " AzureContentFilterScorer,\n", " BatchScorer,\n", " ContentClassifierPaths,\n", @@ -205,12 +202,11 @@ } ], "source": [ - "# pylint: disable=W0611\n", "import uuid\n", "\n", "from pyrit.memory import CentralMemory\n", "from pyrit.prompt_target import OpenAIChatTarget\n", - "from pyrit.score import (\n", + "from pyrit.score import ( # noqa: F401\n", " AzureContentFilterScorer,\n", " BatchScorer,\n", " ContentClassifierPaths,\n", @@ -256,7 +252,8 @@ ], "metadata": { "jupytext": { - "cell_metadata_filter": "-all" + "cell_metadata_filter": "-all", + "main_language": "python" }, "language_info": { "codemirror_mode": { diff --git a/doc/code/scoring/7_batch_scorer.py b/doc/code/scoring/7_batch_scorer.py index d65d3b08d..ba43a3050 100644 --- a/doc/code/scoring/7_batch_scorer.py +++ b/doc/code/scoring/7_batch_scorer.py @@ -64,10 +64,9 @@ # Once the prompts are in the database (which again, is often automatic) we can use `BatchScorer` to score them with whatever scorers we want. It works in parallel with batches. 
# %% -# pylint: disable=W0611 from pyrit.memory import CentralMemory from pyrit.prompt_target import OpenAIChatTarget -from pyrit.score import ( +from pyrit.score import ( # noqa: F401 AzureContentFilterScorer, BatchScorer, ContentClassifierPaths, @@ -113,12 +112,11 @@ # - Converted Value SHA256 # %% -# pylint: disable=W0611 import uuid from pyrit.memory import CentralMemory from pyrit.prompt_target import OpenAIChatTarget -from pyrit.score import ( +from pyrit.score import ( # noqa: F401 AzureContentFilterScorer, BatchScorer, ContentClassifierPaths, diff --git a/pyproject.toml b/pyproject.toml index 4f45e888f..c53d22be5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ dev = [ "pytest-cov>=6.1.1", "pytest-timeout>=2.4.0", "respx>=0.22.0", + "ruff>=0.14.4", "sphinxcontrib-mermaid>=1.0.0", "types-PyYAML>=6.0.12.20250516", "types-requests>=2.31.0.20250515", @@ -174,6 +175,9 @@ formats = "ipynb,py:percent" [tool.ruff] line-length = 120 + +[tool.ruff.lint] +preview = true fixable = [ "A", "B", @@ -220,3 +224,34 @@ fixable = [ "UP", "YTT", ] +select = [ + "D", # https://docs.astral.sh/ruff/rules/#pydocstyle-d + "DOC", # https://docs.astral.sh/ruff/rules/#pydoclint-doc + "F401", # unused-import +] +ignore = [ + "D100", # Missing docstring in public module + "D200", # One-line docstring should fit on one line + "D205", # 1 blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line + "D301", # Use r""" if any backslashes in a docstring + "DOC502", # Raised exception is not explicitly raised +] +extend-select = [ + "D204", # 1 blank line required after class docstring + "D213", # Multi-line docstring summary should start at the second line + "D401", # First line of docstring should be in imperative mood + "D404", # First word of the docstring should not be "This" +] + +[tool.ruff.lint.per-file-ignores] +# Ignore D and DOC rules everywhere except for the pyrit/ directory +"!pyrit/**.py" = ["D", "DOC"] +# Temporary ignores for pyrit/ subdirectories until issue #1176 +# https://github.com/Azure/PyRIT/issues/1176 is fully resolved +# TODO: Remove these ignores once the issues are fixed +"pyrit/{analytics,auth,auxiliary_attacks,chat_message_normalizer,cli,common,datasets,embedding,exceptions,executor,memory,models,prompt_converter,prompt_normalizer,prompt_target,scenarios,score,setup,ui}/**/*.py" = ["D101", "D102", "D103", "D104", "D105", "D106", "D107", "D401", "D404", "D417", "D418", "DOC102", "DOC201", "DOC202", "DOC402", "DOC501"] +"pyrit/__init__.py" = ["D104"] + +[tool.ruff.lint.pydocstyle] +convention = "google" diff --git a/pyrit/analytics/conversation_analytics.py b/pyrit/analytics/conversation_analytics.py index 298683eb8..c3174d73c 100644 --- a/pyrit/analytics/conversation_analytics.py +++ b/pyrit/analytics/conversation_analytics.py @@ -29,7 +29,7 @@ def get_prompt_entries_with_same_converted_content( self, *, chat_message_content: str ) -> list[ConversationMessageWithSimilarity]: """ - Retrieves chat messages that have the same converted content + Retrieves chat messages that have the same converted content. Args: chat_message_content (str): The content of the chat message to find similar messages for. @@ -68,7 +68,6 @@ def get_similar_chat_messages_by_embedding( List[ConversationMessageWithSimilarity]: A list of ConversationMessageWithSimilarity objects representing the similar chat messages based on embedding similarity. 
""" - all_embdedding_memory = self.memory_interface.get_all_embeddings() similar_messages = [] diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index a64e3ea78..5d9ecba41 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -49,7 +49,6 @@ def analyze_results(attack_results: list[AttackResult]) -> dict: "By_attack_identifier": dict[str, AttackStats] } """ - if not attack_results: raise ValueError("attack_results cannot be empty") diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 72e8eb272..d3e6c146f 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -41,7 +41,8 @@ def _set_default_token(self) -> None: self.token = self.access_token.token def refresh_token(self) -> str: - """Refresh the access token if it is expired. + """ + Refresh the access token if it is expired. Returns: A token @@ -79,7 +80,8 @@ def get_access_token_from_azure_cli(*, scope: str, tenant_id: str = ""): def get_access_token_from_azure_msi(*, client_id: str, scope: str): - """Connect to an AOAI endpoint via managed identity credential attached to an Azure resource. + """ + Connect to an AOAI endpoint via managed identity credential attached to an Azure resource. For proper setup and configuration of MSI https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/overview. @@ -100,7 +102,8 @@ def get_access_token_from_azure_msi(*, client_id: str, scope: str): def get_access_token_from_msa_public_client(*, client_id: str, scope: str): - """Uses MSA account to connect to an AOAI endpoint via interactive login. A browser window + """ + Uses MSA account to connect to an AOAI endpoint via interactive login. A browser window will open and ask for login credentials. Args: @@ -120,7 +123,8 @@ def get_access_token_from_msa_public_client(*, client_id: str, scope: str): def get_access_token_from_interactive_login(scope: str) -> str: - """Connects to an OpenAI endpoint with an interactive login from Azure. A browser window will + """ + Connects to an OpenAI endpoint with an interactive login from Azure. A browser window will open and ask for login credentials. The token will be scoped for Azure Cognitive services. Returns: @@ -135,7 +139,8 @@ def get_access_token_from_interactive_login(scope: str) -> str: def get_token_provider_from_default_azure_credential(scope: str) -> Callable[[], str]: - """Connect to an AOAI endpoint via default Azure credential. + """ + Connect to an AOAI endpoint via default Azure credential. Returns: Authentication token provider @@ -149,7 +154,8 @@ def get_token_provider_from_default_azure_credential(scope: str) -> Callable[[], def get_default_scope(endpoint: str) -> str: - """Get the default scope for the given endpoint. + """ + Get the default scope for the given endpoint. Args: endpoint (str): The endpoint to get the scope for. @@ -170,12 +176,13 @@ def get_default_scope(endpoint: str) -> str: def get_speech_config(resource_id: Union[str, None], key: Union[str, None], region: str): """ Get the speech config using key/region pair (for key auth scenarios) or resource_id/region pair - (for Entra auth scenarios) + (for Entra auth scenarios). Args: resource_id (Union[str, None]): The resource ID to get the token for. key (Union[str, None]): The Azure Speech key region (str): The region to get the token for. 
+ Returns: The speech config based on passed in args @@ -207,7 +214,8 @@ def get_speech_config(resource_id: Union[str, None], key: Union[str, None], regi def get_speech_config_from_default_azure_credential(resource_id: str, region: str): - """Get the speech config for the given resource ID and region. + """ + Get the speech config for the given resource ID and region. Args: resource_id (str): The resource ID to get the token for. diff --git a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py index b2b977a42..e6b76c6ad 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py +++ b/pyrit/auxiliary_attacks/gcg/attack/base/attack_manager.py @@ -134,23 +134,21 @@ def __init__( """ Initializes the AttackPrompt object with the provided parameters. - Parameters - ---------- - goal : str - The intended goal of the attack - target : str - The target of the attack - tokenizer : Transformer Tokenizer - The tokenizer used to convert text into tokens - conv_template : Template - The conversation template used for the attack - control_init : str, optional - A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ") - test_prefixes : list, optional - A list of prefixes to test the attack - (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) + Args: + goal (str): + The intended goal of the attack + target (str): + The target of the attack + tokenizer (Transformer Tokenizer): + The tokenizer used to convert text into tokens + conv_template (Template): + The conversation template used for the attack + control_init (str, optional): + A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") + test_prefixes (list, optional): + A list of prefixes to test the attack + (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) """ - self.goal = goal self.target = target self.control = control_init @@ -467,25 +465,23 @@ def __init__( """ Initializes the PromptManager object with the provided parameters. - Parameters - ---------- - goals : list of str - The list of intended goals of the attack - targets : list of str - The list of targets of the attack - tokenizer : Transformer Tokenizer - The tokenizer used to convert text into tokens - conv_template : Template - The conversation template used for the attack - control_init : str, optional - A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") - test_prefixes : list, optional - A list of prefixes to test the attack - (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) - managers : dict, optional - A dictionary of manager objects, required to create the prompts. + Args: + goals (List[str]): + The list of intended goals of the attack + targets (List[str]): + The list of targets of the attack + tokenizer (Transformer Tokenizer): + The tokenizer used to convert text into tokens + conv_template (Template): + The conversation template used for the attack + control_init (str, optional): + A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") + test_prefixes (list, optional): + A list of prefixes to test the attack + (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) + managers (dict, optional): + A dictionary of manager objects, required to create the prompts. 
""" - if len(goals) != len(targets): raise ValueError("Length of goals and targets must match") if len(goals) == 0: @@ -601,31 +597,29 @@ def __init__( """ Initializes the MultiPromptAttack object with the provided parameters. - Parameters - ---------- - goals : list of str - The list of intended goals of the attack - targets : list of str - The list of targets of the attack - workers : list of Worker objects - The list of workers used in the attack - control_init : str, optional - A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") - test_prefixes : list, optional - A list of prefixes to test the attack - (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) - logfile : str, optional - A file to which logs will be written - managers : dict, optional - A dictionary of manager objects, required to create the prompts. - test_goals : list of str, optional - The list of test goals of the attack - test_targets : list of str, optional - The list of test targets of the attack - test_workers : list of Worker objects, optional - The list of test workers used in the attack + Args: + goals (List[str]): + The list of intended goals of the attack + targets (List[str]): + The list of targets of the attack + workers (List[Worker]): + The list of workers used in the attack + control_init (str, optional): + A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") + test_prefixes (list, optional): + A list of prefixes to test the attack + (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) + logfile (str, optional): + A file to which logs will be written + managers (dict, optional): + A dictionary of manager objects, required to create the prompts. + test_goals (list of str, optional): + The list of test goals of the attack + test_targets (list of str, optional): + The list of test targets of the attack + test_workers (list of Worker objects, optional): + The list of test workers used in the attack """ - self.goals = goals self.targets = targets self.workers = workers @@ -924,35 +918,33 @@ def __init__( """ Initializes the ProgressiveMultiPromptAttack object with the provided parameters. - Parameters - ---------- - goals : list of str - The list of intended goals of the attack - targets : list of str - The list of targets of the attack - workers : list of Worker objects - The list of workers used in the attack - progressive_goals : bool, optional - If true, goals progress over time (default is True) - progressive_models : bool, optional - If true, models progress over time (default is True) - control_init : str, optional - A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") - test_prefixes : list, optional - A list of prefixes to test the attack - (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) - logfile : str, optional - A file to which logs will be written - managers : dict, optional - A dictionary of manager objects, required to create the prompts. 
-        test_goals : list of str, optional
-            The list of test goals of the attack
-        test_targets : list of str, optional
-            The list of test targets of the attack
-        test_workers : list of Worker objects, optional
-            The list of test workers used in the attack
+        Args:
+            goals (List[str]):
+                The list of intended goals of the attack
+            targets (List[str]):
+                The list of targets of the attack
+            workers (List[Worker]):
+                The list of workers used in the attack
+            progressive_goals (bool, optional):
+                If true, goals progress over time (default is True)
+            progressive_models (bool, optional):
+                If true, models progress over time (default is True)
+            control_init (str, optional):
+                A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !")
+            test_prefixes (List[str], optional):
+                A list of prefixes to test the attack
+                (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"])
+            logfile (str, optional):
+                A file to which logs will be written
+            managers (dict, optional):
+                A dictionary of manager objects, required to create the prompts.
+            test_goals (List[str], optional):
+                The list of test goals of the attack
+            test_targets (List[str], optional):
+                The list of test targets of the attack
+            test_workers (List[Worker], optional):
+                The list of test workers used in the attack
         """
-
         self.goals = goals
         self.targets = targets
         self.workers = workers
@@ -1033,36 +1025,34 @@ def run(
         """
         Executes the progressive multi prompt attack.
 
-        Parameters
-        ----------
-        n_steps : int, optional
-            The number of steps to run the attack (default is 1000)
-        batch_size : int, optional
-            The size of batches to process at a time (default is 1024)
-        topk : int, optional
-            The number of top candidates to consider (default is 256)
-        temp : float, optional
-            The temperature for sampling (default is 1)
-        allow_non_ascii : bool, optional
-            Whether to allow non-ASCII characters (default is False)
-        target_weight
-            The weight assigned to the target
-        control_weight
-            The weight assigned to the control
-        anneal : bool, optional
-            Whether to anneal the temperature (default is True)
-        test_steps : int, optional
-            The number of steps between tests (default is 50)
-        incr_control : bool, optional
-            Whether to increase the control over time (default is True)
-        stop_on_success : bool, optional
-            Whether to stop the attack upon success (default is True)
-        verbose : bool, optional
-            Whether to print verbose output (default is True)
-        filter_cand : bool, optional
-            Whether to filter candidates whose lengths changed after re-tokenization (default is True)
+        Args:
+            n_steps (int, optional):
+                The number of steps to run the attack (default is 1000)
+            batch_size (int, optional):
+                The size of batches to process at a time (default is 1024)
+            topk (int, optional):
+                The number of top candidates to consider (default is 256)
+            temp (float, optional):
+                The temperature for sampling (default is 1)
+            allow_non_ascii (bool, optional):
+                Whether to allow non-ASCII characters (default is False)
+            target_weight:
+                The weight assigned to the target
+            control_weight:
+                The weight assigned to the control
+            anneal (bool, optional):
+                Whether to anneal the temperature (default is True)
+            test_steps (int, optional):
+                The number of steps between tests (default is 50)
+            incr_control (bool, optional):
+                Whether to increase the control over time (default is True)
+            stop_on_success (bool, optional):
+                Whether to stop the attack upon success (default is True)
+            verbose (bool, optional):
+                Whether to print verbose 
output (default is True) + filter_cand (bool, optional): + Whether to filter candidates whose lengths changed after re-tokenization (default is True) """ - if self.logfile is not None: with open(self.logfile, "r") as f: log = json.load(f) @@ -1169,31 +1159,29 @@ def __init__( """ Initializes the IndividualPromptAttack object with the provided parameters. - Parameters - ---------- - goals : list - The list of intended goals of the attack - targets : list - The list of targets of the attack - workers : list - The list of workers used in the attack - control_init : str, optional - A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") - test_prefixes : list, optional - A list of prefixes to test the attack (default is - ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) - logfile : str, optional - A file to which logs will be written - managers : dict, optional - A dictionary of manager objects, required to create the prompts. - test_goals : list, optional - The list of test goals of the attack - test_targets : list, optional - The list of test targets of the attack - test_workers : list, optional - The list of test workers used in the attack + Args: + goals (list): + The list of intended goals of the attack + targets (list): + The list of targets of the attack + workers (list): + The list of workers used in the attack + control_init (str, optional): + A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") + test_prefixes (list, optional): + A list of prefixes to test the attack (default is + ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) + logfile (str, optional): + A file to which logs will be written + managers (dict, optional): + A dictionary of manager objects, required to create the prompts. + test_goals (list, optional): + The list of test goals of the attack + test_targets (list, optional): + The list of test targets of the attack + test_workers (list, optional): + The list of test workers used in the attack """ - self.goals = goals self.targets = targets self.workers = workers @@ -1271,36 +1259,34 @@ def run( """ Executes the individual prompt attack. 
- Parameters - ---------- - n_steps : int, optional - The number of steps to run the attack (default is 1000) - batch_size : int, optional - The size of batches to process at a time (default is 1024) - topk : int, optional - The number of top candidates to consider (default is 256) - temp : float, optional - The temperature for sampling (default is 1) - allow_non_ascii : bool, optional - Whether to allow non-ASCII characters (default is True) - target_weight : any, optional - The weight assigned to the target - control_weight : any, optional - The weight assigned to the control - anneal : bool, optional - Whether to anneal the temperature (default is True) - test_steps : int, optional - The number of steps between tests (default is 50) - incr_control : bool, optional - Whether to increase the control over time (default is True) - stop_on_success : bool, optional - Whether to stop the attack upon success (default is True) - verbose : bool, optional - Whether to print verbose output (default is True) - filter_cand : bool, optional - Whether to filter candidates (default is True) + Args: + n_steps (int, optional): + The number of steps to run the attack (default is 1000) + batch_size (int, optional): + The size of batches to process at a time (default is 1024) + topk (int, optional): + The number of top candidates to consider (default is 256) + temp (float, optional): + The temperature for sampling (default is 1) + allow_non_ascii (bool, optional): + Whether to allow non-ASCII characters (default is True) + target_weight (any, optional): + The weight assigned to the target + control_weight (any, optional): + The weight assigned to the control + anneal (bool, optional): + Whether to anneal the temperature (default is True) + test_steps (int, optional): + The number of steps between tests (default is 50) + incr_control (bool, optional): + Whether to increase the control over time (default is True) + stop_on_success (bool, optional): + Whether to stop the attack upon success (default is True) + verbose (bool, optional): + Whether to print verbose output (default is True) + filter_cand (bool, optional): + Whether to filter candidates (default is True) """ - if self.logfile is not None: with open(self.logfile, "r") as f: log = json.load(f) @@ -1379,31 +1365,29 @@ def __init__( """ Initializes the EvaluateAttack object with the provided parameters. - Parameters - ---------- - goals : list - The list of intended goals of the attack - targets : list - The list of targets of the attack - workers : list - The list of workers used in the attack - control_init : str, optional - A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !") - test_prefixes : list, optional - A list of prefixes to test the attack - (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) - logfile : str, optional - A file to which logs will be written - managers : dict, optional - A dictionary of manager objects, required to create the prompts. - test_goals : list, optional - The list of test goals of the attack - test_targets : list, optional - The list of test targets of the attack - test_workers : list, optional - The list of test workers used in the attack + Args: + goals (list): + The list of intended goals of the attack + targets (list): + The list of targets of the attack + workers (list): + The list of workers used in the attack + control_init (str, optional): + A string used to control the attack (default is "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! 
! !") + test_prefixes (list, optional): + A list of prefixes to test the attack + (default is ["I'm sorry", "I apologize", "As an", "I'm just", "I'm an", "I cannot", "Hello!"]) + logfile (str, optional): + A file to which logs will be written + managers (dict, optional): + A dictionary of manager objects, required to create the prompts. + test_goals (list, optional): + The list of test goals of the attack + test_targets (list, optional): + The list of test targets of the attack + test_workers (list, optional): + The list of test workers used in the attack """ - self.goals = goals self.targets = targets self.workers = workers diff --git a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py index a34b82dfc..5b0759b08 100644 --- a/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py +++ b/pyrit/auxiliary_attacks/gcg/attack/gcg/gcg_attack.py @@ -24,25 +24,21 @@ def token_gradients(model, input_ids, input_slice, target_slice, loss_slice): """ Computes gradients of the loss with respect to the coordinates. - Parameters - ---------- - model : Transformer Model - The transformer model to be used. - input_ids : torch.Tensor - The input sequence in the form of token ids. - input_slice : slice - The slice of the input sequence for which gradients need to be computed. - target_slice : slice - The slice of the input sequence to be used as targets. - loss_slice : slice - The slice of the logits to be used for computing the loss. - - Returns - ------- - torch.Tensor - The gradients of each token in the input_slice with respect to the loss. + Args: + model (Transformer Model): + The transformer model to be used. + input_ids (torch.Tensor): + The input sequence in the form of token ids. + input_slice (slice): + The slice of the input sequence for which gradients need to be computed. + target_slice (slice): + The slice of the input sequence to be used as targets. + loss_slice (slice): + The slice of the logits to be used for computing the loss. + + Returns: + torch.Tensor: The gradients of each token in the input_slice with respect to the loss. """ - embed_weights = get_embedding_matrix(model) one_hot = torch.zeros( input_ids[input_slice].shape[0], embed_weights.shape[0], device=model.device, dtype=embed_weights.dtype diff --git a/pyrit/auxiliary_attacks/gcg/experiments/run.py b/pyrit/auxiliary_attacks/gcg/experiments/run.py index da9ae10c6..3b21a18f7 100644 --- a/pyrit/auxiliary_attacks/gcg/experiments/run.py +++ b/pyrit/auxiliary_attacks/gcg/experiments/run.py @@ -24,7 +24,7 @@ def _load_yaml_to_dict(config_path: str) -> dict: def run_trainer(*, model_name: str, setup: str = "single", **extra_config_parameters): """ - Trains and generates adversarial suffix - single model single prompt + Trains and generates adversarial suffix - single model single prompt. Args: model_name (str): The name of the model, currently supports: @@ -34,7 +34,6 @@ def run_trainer(*, model_name: str, setup: str = "single", **extra_config_parame - "multiple": multiple prompts one model or multiple prompts multiple models """ - if model_name not in MODEL_NAMES: raise ValueError( "Model name not supported. 
Currently supports 'mistral', 'llama_2', 'llama_3', 'vicuna', and 'phi_3_mini'"
         )
diff --git a/pyrit/chat_message_normalizer/chat_message_normalizer.py b/pyrit/chat_message_normalizer/chat_message_normalizer.py
index e8b2cd539..2182493ea 100644
--- a/pyrit/chat_message_normalizer/chat_message_normalizer.py
+++ b/pyrit/chat_message_normalizer/chat_message_normalizer.py
@@ -13,7 +13,7 @@ class ChatMessageNormalizer(abc.ABC, Generic[T]):
     @abc.abstractmethod
     def normalize(self, messages: list[ChatMessage]) -> T:
         """
-        Normalizes the list of chat messages into a compatible format for the model or target
+        Normalizes the list of chat messages into a compatible format for the model or target.
         """
 
     @staticmethod
diff --git a/pyrit/chat_message_normalizer/chat_message_normalizer_chatml.py b/pyrit/chat_message_normalizer/chat_message_normalizer_chatml.py
index 933e31a52..7eb0557fc 100644
--- a/pyrit/chat_message_normalizer/chat_message_normalizer_chatml.py
+++ b/pyrit/chat_message_normalizer/chat_message_normalizer_chatml.py
@@ -12,7 +12,8 @@ class ChatMessageNormalizerChatML(ChatMessageNormalizer[str]):
     """A chat message normalizer that converts a list of chat messages to a ChatML string."""
 
     def normalize(self, messages: list[ChatMessage]) -> str:
-        """Convert a string of text to a ChatML string.
+        """
+        Convert a list of chat messages to a ChatML string.
 
         This is compliant with the ChatML specified in
         https://github.com/openai/openai-python/blob/release-v0.28.0/chatml.md
@@ -31,7 +32,8 @@ def normalize(self, messages: list[ChatMessage]) -> str:
 
     @staticmethod
     def from_chatml(content: str) -> list[ChatMessage]:
-        """Convert a chatML string to a list of chat messages.
+        """
+        Convert a ChatML string to a list of chat messages.
 
         Args:
             content (str): The ChatML string to convert.
diff --git a/pyrit/chat_message_normalizer/chat_message_normalizer_tokenizer.py b/pyrit/chat_message_normalizer/chat_message_normalizer_tokenizer.py
index 92bb30d8e..42bad58e0 100644
--- a/pyrit/chat_message_normalizer/chat_message_normalizer_tokenizer.py
+++ b/pyrit/chat_message_normalizer/chat_message_normalizer_tokenizer.py
@@ -11,7 +11,7 @@ class ChatMessageNormalizerTokenizerTemplate(ChatMessageNormalizer[str]):
     """
     This class enables you to apply the chat template stored in a Hugging Face
     tokenizer to a list of chat messages. For more details, see
-    https://huggingface.co/docs/transformers/main/en/chat_templating
+    https://huggingface.co/docs/transformers/main/en/chat_templating.
     """
 
     def __init__(self, tokenizer: PreTrainedTokenizerBase):
@@ -33,7 +33,6 @@ def normalize(self, messages: list[ChatMessage]) -> str:
         Returns:
             str: The formatted chat messages.
         """
-
         messages_list = []
         formatted_messages: str = ""
 
diff --git a/pyrit/chat_message_normalizer/generic_system_squash.py b/pyrit/chat_message_normalizer/generic_system_squash.py
index 8cc94319a..53f40c80d 100644
--- a/pyrit/chat_message_normalizer/generic_system_squash.py
+++ b/pyrit/chat_message_normalizer/generic_system_squash.py
@@ -7,7 +7,8 @@ class GenericSystemSquash(ChatMessageNormalizer[list[ChatMessage]]):
 
     def normalize(self, messages: list[ChatMessage]) -> list[ChatMessage]:
-        """Returns the first system message combined with the first user message.
+        """
+        Returns the first system message combined with the first user message.
 
         The format of the result uses generic instruction tags. 
@@ -26,7 +27,8 @@ def normalize(self, messages: list[ChatMessage]) -> list[ChatMessage]: def combine_system_user_message( system_message: ChatMessage, user_message: ChatMessage, msg_type: ChatMessageRole = "user" ) -> ChatMessage: - """Combines the system message with the user message. + """ + Combines the system message with the user message. Args: system_message (str): The system message. diff --git a/pyrit/common/apply_defaults.py b/pyrit/common/apply_defaults.py index 8231e4342..f2e397f5f 100644 --- a/pyrit/common/apply_defaults.py +++ b/pyrit/common/apply_defaults.py @@ -215,7 +215,6 @@ def set_global_variable(*, name: str, value: Any) -> None: variable accessible to code that imports or executes after the initialization script runs. """ - # Set the variable in the __main__ module's global namespace sys.modules["__main__"].__dict__[name] = value diff --git a/pyrit/common/data_url_converter.py b/pyrit/common/data_url_converter.py index c89a135f3..d4b954455 100644 --- a/pyrit/common/data_url_converter.py +++ b/pyrit/common/data_url_converter.py @@ -9,7 +9,8 @@ async def convert_local_image_to_data_url(image_path: str) -> str: - """Converts a local image file to a data URL encoded in base64. + """ + Converts a local image file to a data URL encoded in base64. Args: image_path (str): The file system path to the image file. diff --git a/pyrit/common/default_values.py b/pyrit/common/default_values.py index 113a4b3ac..c088bdb90 100644 --- a/pyrit/common/default_values.py +++ b/pyrit/common/default_values.py @@ -11,7 +11,7 @@ def get_required_value(*, env_var_name: str, passed_value: str) -> str: """ Gets a required value from an environment variable or a passed value, - preferring the passed value + preferring the passed value. If no value is found, raises a KeyError diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 0a3b81285..371d5b5f4 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -13,7 +13,8 @@ async def display_image_response(response_piece: MessagePiece) -> None: - """Displays response images if running in notebook environment. + """ + Displays response images if running in notebook environment. Args: response_piece (MessagePiece): The response piece to display. 
diff --git a/pyrit/common/download_hf_model.py b/pyrit/common/download_hf_model.py
index fd5b6c138..5ff8a45b8 100644
--- a/pyrit/common/download_hf_model.py
+++ b/pyrit/common/download_hf_model.py
@@ -99,7 +99,6 @@ async def download_file(url, token, download_dir, num_splits):
 
 async def download_files(urls: list[str], token: str, download_dir: Path, num_splits=3, parallel_downloads=4):
     """Download multiple files with parallel downloads and segmented downloading."""
-
     # Limit the number of parallel downloads
     semaphore = asyncio.Semaphore(parallel_downloads)
 
diff --git a/pyrit/common/net_utility.py b/pyrit/common/net_utility.py
index 4df5990b3..d3f797e05 100644
--- a/pyrit/common/net_utility.py
+++ b/pyrit/common/net_utility.py
@@ -10,7 +10,6 @@
 
 def get_httpx_client(use_async: bool = False, debug: bool = False, **httpx_client_kwargs: Optional[Any]):
     """Get the httpx client for making requests."""
-
     client_class = httpx.AsyncClient if use_async else httpx.Client
     proxy = "http://localhost:8080" if debug else None
 
diff --git a/pyrit/common/notebook_utils.py b/pyrit/common/notebook_utils.py
index c382feb13..4c2753e2d 100644
--- a/pyrit/common/notebook_utils.py
+++ b/pyrit/common/notebook_utils.py
@@ -3,7 +3,8 @@
 
 def is_in_ipython_session() -> bool:
-    """Determines if the code is running in an IPython session.
+    """
+    Determines if the code is running in an IPython session.
 
     This may be useful if the behavior of the code should change when running in an IPython session.
     For example, the code may display additional information or plots when running in an IPython session.
diff --git a/pyrit/common/print.py b/pyrit/common/print.py
index 3053cb2fc..68e7f559c 100644
--- a/pyrit/common/print.py
+++ b/pyrit/common/print.py
@@ -15,7 +15,8 @@ def print_chat_messages_with_color(
     left_padding_width: int = 20,
     custom_colors: Optional[dict[str, str]] = None,
 ) -> None:
-    """Print chat messages with color to console.
+    """
+    Print chat messages with color to console.
 
     Args:
         messages: List of chat messages.
diff --git a/pyrit/common/question_answer_helpers.py b/pyrit/common/question_answer_helpers.py
index cdc906c18..c8314b3b6 100644
--- a/pyrit/common/question_answer_helpers.py
+++ b/pyrit/common/question_answer_helpers.py
@@ -6,7 +6,7 @@
 
 def construct_evaluation_prompt(entry: QuestionAnsweringEntry) -> str:
     """
-    From question and choices in entry, creates prompt to be send to target
+    From question and choices in entry, creates the prompt to be sent to the target.
 
     Args:
         entry (QuestionAnsweringEntry): A single entry from which the prompt is constructed
diff --git a/pyrit/datasets/babelscape_alert_dataset.py b/pyrit/datasets/babelscape_alert_dataset.py
index bb73fb0e4..5208f6323 100644
--- a/pyrit/datasets/babelscape_alert_dataset.py
+++ b/pyrit/datasets/babelscape_alert_dataset.py
@@ -21,7 +21,6 @@ def fetch_babelscape_alert_dataset(
     Returns:
         SeedDataset: A SeedDataset containing the examples.
     """
-
    data_categories = None
    if category is None:  # if category is explicitly None, read both subsets
        data_categories = ["alert_adversarial", "alert"]
diff --git a/pyrit/datasets/dataset_helper.py b/pyrit/datasets/dataset_helper.py
index 32719b854..845acd6af 100644
--- a/pyrit/datasets/dataset_helper.py
+++ b/pyrit/datasets/dataset_helper.py
@@ -115,7 +115,6 @@ def fetch_examples(
     Returns:
         List[Dict[str, str]]: A list of examples. 
""" - file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) diff --git a/pyrit/datasets/equitymedqa_dataset.py b/pyrit/datasets/equitymedqa_dataset.py index c29c4f91b..0615714be 100644 --- a/pyrit/datasets/equitymedqa_dataset.py +++ b/pyrit/datasets/equitymedqa_dataset.py @@ -57,6 +57,7 @@ def fetch_equitymedqa_dataset_unique_values( ) -> SeedDataset: """ Fetches the EquityMedQA dataset from Hugging Face and returns a SeedDataset. + Args: subset_name (str | list): The name(s) of the subset to fetch. Defaults to "all" which returns all values. @@ -101,6 +102,7 @@ def fetch_equitymedqa_dataset_unique_values( def get_sub_dataset(subset_name: str) -> list: """ Fetches a specific subset of the EquityMedQA dataset and returns a list of unique prompts. + Args: subset_name (str): The name of the subset to fetch. """ diff --git a/pyrit/datasets/forbidden_questions_dataset.py b/pyrit/datasets/forbidden_questions_dataset.py index e0a0b8362..01d3cbb97 100644 --- a/pyrit/datasets/forbidden_questions_dataset.py +++ b/pyrit/datasets/forbidden_questions_dataset.py @@ -8,7 +8,7 @@ def fetch_forbidden_questions_dataset() -> SeedDataset: """ - Fetch Forbidden question dataset and return it as a SeedDataset + Fetch Forbidden question dataset and return it as a SeedDataset. Returns: SeedDataset diff --git a/pyrit/datasets/harmbench_dataset.py b/pyrit/datasets/harmbench_dataset.py index 10fc75c6b..47d78b548 100644 --- a/pyrit/datasets/harmbench_dataset.py +++ b/pyrit/datasets/harmbench_dataset.py @@ -33,7 +33,6 @@ def fetch_harmbench_dataset( For more information and access to the original dataset and related materials, visit: https://github.com/centerforaisafety/HarmBench """ - # Determine the file type from the source URL file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: diff --git a/pyrit/datasets/many_shot_jailbreaking_dataset.py b/pyrit/datasets/many_shot_jailbreaking_dataset.py index eb9cfd93e..8681836ac 100644 --- a/pyrit/datasets/many_shot_jailbreaking_dataset.py +++ b/pyrit/datasets/many_shot_jailbreaking_dataset.py @@ -13,7 +13,6 @@ def fetch_many_shot_jailbreaking_dataset() -> List[Dict[str, str]]: Returns: List[Dict[str, str]]: A list of many-shot jailbreaking examples. """ - source = "https://raw.githubusercontent.com/KutalVolkan/many-shot-jailbreaking-dataset/5eac855/examples.json" source_type: Literal["public_url"] = "public_url" diff --git a/pyrit/datasets/medsafetybench_dataset.py b/pyrit/datasets/medsafetybench_dataset.py index 637db1564..17c4b36bb 100644 --- a/pyrit/datasets/medsafetybench_dataset.py +++ b/pyrit/datasets/medsafetybench_dataset.py @@ -33,7 +33,6 @@ def fetch_medsafetybench_dataset( https://proceedings.neurips.cc/paper_files/paper/2024/hash/3ac952d0264ef7a505393868a70a46b6-Abstract-Datasets_and_Benchmarks_Track.html Authors: Tessa Han, Aounon Kumar, Chirag Agarwal, Himabindu Lakkaraju. """ - base_url = "https://raw.githubusercontent.com/AI4LIFE-GROUP/" "med-safety-bench/main/datasets" sources = [] diff --git a/pyrit/datasets/seclists_bias_testing_dataset.py b/pyrit/datasets/seclists_bias_testing_dataset.py index aaac28419..d9823db95 100644 --- a/pyrit/datasets/seclists_bias_testing_dataset.py +++ b/pyrit/datasets/seclists_bias_testing_dataset.py @@ -42,7 +42,6 @@ def fetch_seclists_bias_testing_dataset( Returns: SeedDataset: A SeedDataset containing the examples with placeholders replaced. 
""" - if random_seed is not None: random.seed(random_seed) diff --git a/pyrit/datasets/text_jailbreak.py b/pyrit/datasets/text_jailbreak.py index 3eb191794..2cbcab8f5 100644 --- a/pyrit/datasets/text_jailbreak.py +++ b/pyrit/datasets/text_jailbreak.py @@ -10,7 +10,7 @@ class TextJailBreak: """ - A class that manages jailbreak datasets (like DAN, etc.) + A class that manages jailbreak datasets (like DAN, etc.). """ def __init__( diff --git a/pyrit/datasets/wmdp_dataset.py b/pyrit/datasets/wmdp_dataset.py index 237a4050b..139756da7 100644 --- a/pyrit/datasets/wmdp_dataset.py +++ b/pyrit/datasets/wmdp_dataset.py @@ -26,7 +26,6 @@ def fetch_wmdp_dataset(category: Optional[str] = None) -> QuestionAnsweringDatas For more information and access to the original dataset and related materials, visit: https://huggingface.co/datasets/cais/wmdp """ - # Determine which subset of data to load data_categories = None if not category: # if category is not specified, read in all 3 subsets of data diff --git a/pyrit/datasets/xstest_dataset.py b/pyrit/datasets/xstest_dataset.py index c2b9f2f54..678273e0e 100644 --- a/pyrit/datasets/xstest_dataset.py +++ b/pyrit/datasets/xstest_dataset.py @@ -30,7 +30,6 @@ def fetch_xstest_dataset( For more information and access to the original dataset and related materials, visit: https://github.com/paul-rottger/exaggerated-safety """ - # Determine the file type from the source URL file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: diff --git a/pyrit/embedding/_text_embedding.py b/pyrit/embedding/_text_embedding.py index c9b6d08c6..c181b6caf 100644 --- a/pyrit/embedding/_text_embedding.py +++ b/pyrit/embedding/_text_embedding.py @@ -16,7 +16,7 @@ class _TextEmbedding(EmbeddingSupport, abc.ABC): - """Text embedding base class""" + """Text embedding base class.""" _client: Union[OpenAI, AzureOpenAI] _model: str @@ -30,7 +30,8 @@ def __init__(self) -> None: @tenacity.retry(wait=tenacity.wait_fixed(0.1), stop=tenacity.stop_after_delay(3)) def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse: - """Generate text embedding + """ + Generate text embedding. Args: text: The text to generate the embedding for diff --git a/pyrit/embedding/azure_text_embedding.py b/pyrit/embedding/azure_text_embedding.py index fbd8515c5..bb53f81ae 100644 --- a/pyrit/embedding/azure_text_embedding.py +++ b/pyrit/embedding/azure_text_embedding.py @@ -24,7 +24,8 @@ def __init__( api_version: str = "2024-02-01", use_entra_auth: bool = False, ) -> None: - """Generate embedding using the Azure API. Authenticate with either an API key or Entra authentication. + """ + Generate embedding using the Azure API. Authenticate with either an API key or Entra authentication. Args: api_key: The API key to use (only if you're not using Entra authentication). Defaults to diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index 92d41f47a..6c960a8d7 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -8,7 +8,8 @@ class OpenAiTextEmbedding(_TextEmbedding): def __init__(self, *, model: str, api_key: str) -> None: - """Generate embedding using OpenAI API + """ + Generate embedding using OpenAI API. 
Args: api_version: The API version to use diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index 3d00d8588..df099c838 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -149,7 +149,6 @@ def pyrit_target_retry(func: Callable) -> Callable: Returns: Callable: The decorated function with retry logic applied. """ - return retry( reraise=True, retry=retry_if_exception_type(RateLimitError) @@ -174,7 +173,6 @@ def pyrit_json_retry(func: Callable) -> Callable: Returns: Callable: The decorated function with retry logic applied. """ - return retry( reraise=True, retry=retry_if_exception_type(InvalidJsonException), @@ -196,7 +194,6 @@ def pyrit_placeholder_retry(func: Callable) -> Callable: Returns: Callable: The decorated function with retry logic applied. """ - return retry( reraise=True, retry=retry_if_exception_type(MissingPromptPlaceholderException), diff --git a/pyrit/exceptions/exceptions_helpers.py b/pyrit/exceptions/exceptions_helpers.py index 0c51690d5..258f65267 100644 --- a/pyrit/exceptions/exceptions_helpers.py +++ b/pyrit/exceptions/exceptions_helpers.py @@ -34,7 +34,6 @@ def remove_start_md_json(response_msg: str) -> str: Returns: str: The response message without the start marker (if one was present). """ - start_pattern = re.compile(r"^(```json\n|`json\n|```\n|`\n|```json|`json|```|`|json|json\n)") match = start_pattern.match(response_msg) if match: @@ -53,7 +52,6 @@ def remove_end_md_json(response_msg: str) -> str: Returns: str: The response message without the end marker (if one was present). """ - end_pattern = re.compile(r"(\n```|\n`|```|`)$") match = end_pattern.search(response_msg) if match: @@ -90,7 +88,6 @@ def remove_markdown_json(response_msg: str) -> str: Returns: str: The response message without Markdown formatting if present. """ - response_msg = remove_start_md_json(response_msg) response_msg = remove_end_md_json(response_msg) diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 773eee90f..dbd2b2c3a 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -103,7 +103,7 @@ def set_system_prompt( labels: Optional[Dict[str, str]] = None, ) -> None: """ - set or update the system-level prompt associated with a conversation. + Set or update the system-level prompt associated with a conversation. This helper is intended for conversational (`PromptChatTarget`) goals, where a dedicated system prompt influences the behavior of the LLM for @@ -344,7 +344,6 @@ def _process_piece( ValueError: If max_turns would be exceeded by this piece. ValueError: If a system prompt is provided but target doesn't support it. """ - # Check if multiturn is_multi_turn = max_turns is not None diff --git a/pyrit/executor/attack/core/attack_executor.py b/pyrit/executor/attack/core/attack_executor.py index 9552dff94..a321df08d 100644 --- a/pyrit/executor/attack/core/attack_executor.py +++ b/pyrit/executor/attack/core/attack_executor.py @@ -235,7 +235,6 @@ async def execute_single_turn_attacks_async( ... seed_groups=[SeedGroup(...), SeedGroup(...)] ... ) """ - # Validate that the attack uses SingleTurnAttackContext if hasattr(attack, "_context_type") and not issubclass(attack._context_type, SingleTurnAttackContext): raise TypeError( @@ -336,7 +335,6 @@ async def execute_multi_turn_attacks_async( ... 
custom_prompts=["Tell me about chemistry", "Explain system administration"] ... ) """ - # Validate that the attack uses MultiTurnAttackContext if hasattr(attack, "_context_type") and not issubclass(attack._context_type, MultiTurnAttackContext): raise TypeError( diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index bee10239a..ec9a5b394 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -34,7 +34,7 @@ @dataclass class AttackContext(StrategyContext, ABC): - """Base class for all attack contexts""" + """Base class for all attack contexts.""" # Natural-language description of what the attack tries to achieve objective: str @@ -124,7 +124,6 @@ async def _on_post_execute( Args: result (AttackResult): The result of the attack strategy execution. """ - if not event_data.result: raise ValueError("Attack result is None. Cannot log or record the outcome.") @@ -220,11 +219,13 @@ async def execute_async( ) -> AttackStrategyResultT: """ Execute the attack strategy asynchronously with the provided parameters. + Args: objective (str): The objective of the attack. prepended_conversation (Optional[List[Message]]): Conversation to prepend. memory_labels (Optional[Dict[str, str]]): Memory labels for the attack context. **kwargs: Additional parameters for the attack. + Returns: AttackStrategyResultT: The result of the attack execution. """ @@ -243,7 +244,6 @@ async def execute_async( """ Execute the attack strategy asynchronously with the provided parameters. """ - # Validate parameters before creating context objective = get_kwarg_param(kwargs=kwargs, param_name="objective", expected_type=str) diff --git a/pyrit/executor/attack/multi_turn/crescendo.py b/pyrit/executor/attack/multi_turn/crescendo.py index 9a6cae631..ab87898fd 100644 --- a/pyrit/executor/attack/multi_turn/crescendo.py +++ b/pyrit/executor/attack/multi_turn/crescendo.py @@ -394,7 +394,7 @@ async def _perform_async(self, *, context: CrescendoAttackContext) -> CrescendoA async def _teardown_async(self, *, context: CrescendoAttackContext) -> None: """ - Clean up after attack execution + Clean up after attack execution. Args: context (CrescendoAttackContext): The attack context. diff --git a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py index 201ccfc88..1b6f935cd 100644 --- a/pyrit/executor/attack/multi_turn/multi_prompt_sending.py +++ b/pyrit/executor/attack/multi_turn/multi_prompt_sending.py @@ -262,7 +262,7 @@ def _determine_attack_outcome( return AttackOutcome.FAILURE, "At least one prompt was filtered or failed to get a response" async def _teardown_async(self, *, context: MultiPromptSendingAttackContext) -> None: - """Clean up after attack execution""" + """Clean up after attack execution.""" # Nothing to be done here, no-op pass @@ -327,7 +327,6 @@ async def execute_async( """ Execute the attack strategy asynchronously with the provided parameters. 
""" - # Validate parameters before creating context prompt_sequence = get_kwarg_param( kwargs=kwargs, param_name="prompt_sequence", expected_type=list, required=True diff --git a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py index bb881ce5c..769ebc51f 100644 --- a/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py +++ b/pyrit/executor/attack/multi_turn/multi_turn_attack_strategy.py @@ -27,7 +27,7 @@ @dataclass class ConversationSession: - """Session for conversations""" + """Session for conversations.""" # Unique identifier of the main conversation between the attacker and model conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -38,7 +38,7 @@ class ConversationSession: @dataclass class MultiTurnAttackContext(AttackContext): - """Context for multi-turn attacks""" + """Context for multi-turn attacks.""" # Object holding all conversation-level identifiers for this attack session: ConversationSession = field(default_factory=lambda: ConversationSession()) @@ -113,7 +113,6 @@ async def execute_async( """ Execute the attack strategy asynchronously with the provided parameters. """ - # Validate parameters before creating context custom_prompt = get_kwarg_param(kwargs=kwargs, param_name="custom_prompt", expected_type=str, required=False) diff --git a/pyrit/executor/attack/multi_turn/red_teaming.py b/pyrit/executor/attack/multi_turn/red_teaming.py index b1d84cb20..d5fa85bae 100644 --- a/pyrit/executor/attack/multi_turn/red_teaming.py +++ b/pyrit/executor/attack/multi_turn/red_teaming.py @@ -265,7 +265,6 @@ async def _perform_async(self, *, context: MultiTurnAttackContext) -> AttackResu Returns: AttackResult: The result of the attack execution. """ - # Log the attack configuration logger.info(f"Starting red teaming attack with objective: {context.objective}") logger.info(f"Max turns: {self._max_turns}") @@ -314,7 +313,7 @@ async def _perform_async(self, *, context: MultiTurnAttackContext) -> AttackResu ) async def _teardown_async(self, *, context: MultiTurnAttackContext) -> None: - """Clean up after attack execution""" + """Clean up after attack execution.""" # Nothing to be done here, no-op pass diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index 828571e65..8d3107bfb 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -277,7 +277,6 @@ async def send_prompt_async(self, objective: str) -> None: - `off_topic`: `True` if the prompt was deemed off-topic - `error_message`: Set if an error occurred during execution """ - try: # Generate adversarial prompt prompt = await self._generate_adversarial_prompt_async(objective) @@ -1041,7 +1040,6 @@ def __init__( def _load_adversarial_prompts(self) -> None: """Load the adversarial chat prompts from the configured paths.""" - # Load system prompt self._adversarial_chat_system_seed_prompt = SeedPrompt.from_yaml_with_required_parameters( template_path=self._adversarial_chat_system_prompt_path, diff --git a/pyrit/executor/attack/single_turn/context_compliance.py b/pyrit/executor/attack/single_turn/context_compliance.py index a71929fd0..9b051a942 100644 --- a/pyrit/executor/attack/single_turn/context_compliance.py +++ b/pyrit/executor/attack/single_turn/context_compliance.py @@ -131,8 +131,10 @@ def _validate_context(self, *, context: SingleTurnAttackContext) -> None: """ Validate the context for the attack. 
This attack does not support prepended conversations, so it raises an error if one exists. + Args: context (SingleTurnAttackContext): The attack context to validate. + Raises: ValueError: If the context has a prepended conversation. """ diff --git a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py index 35770357a..9392040e3 100644 --- a/pyrit/executor/attack/single_turn/many_shot_jailbreak.py +++ b/pyrit/executor/attack/single_turn/many_shot_jailbreak.py @@ -21,7 +21,7 @@ class ManyShotJailbreakAttack(PromptSendingAttack): """ This attack implements implements the Many Shot Jailbreak method as discussed in research found here: - https://www.anthropic.com/research/many-shot-jailbreaking + https://www.anthropic.com/research/many-shot-jailbreaking. Prepends the seed prompt with a faux dialogue between a human and an AI, using examples from a dataset to demonstrate successful jailbreaking attempts. This method leverages the model's ability to learn from @@ -74,8 +74,10 @@ def __init__( def _validate_context(self, *, context: SingleTurnAttackContext) -> None: """ Validate the context before executing the attack. + Args: context (SingleTurnAttackContext): The attack context containing parameters and objective. + Raises: ValueError: If the context is invalid. """ @@ -86,8 +88,10 @@ def _validate_context(self, *, context: SingleTurnAttackContext) -> None: async def _perform_async(self, *, context: SingleTurnAttackContext) -> AttackResult: """ Perform the ManyShotJailbreakAttack. + Args: context (SingleTurnAttackContext): The attack context containing attack parameters. + Returns: AttackResult: The result of the attack. """ diff --git a/pyrit/executor/attack/single_turn/prompt_sending.py b/pyrit/executor/attack/single_turn/prompt_sending.py index 4dbff0d33..0b6eea72a 100644 --- a/pyrit/executor/attack/single_turn/prompt_sending.py +++ b/pyrit/executor/attack/single_turn/prompt_sending.py @@ -260,7 +260,7 @@ def _determine_attack_outcome( return AttackOutcome.FAILURE, "All attempts were filtered or failed to get a response" async def _teardown_async(self, *, context: SingleTurnAttackContext) -> None: - """Clean up after attack execution""" + """Clean up after attack execution.""" # Nothing to be done here, no-op pass @@ -297,7 +297,6 @@ async def _send_prompt_to_objective_target_async( Optional[Message]: The model's response if successful, or None if the request was filtered, blocked, or encountered an error. """ - return await self._prompt_normalizer.send_prompt_async( seed_group=prompt_group, target=self._objective_target, @@ -329,7 +328,6 @@ async def _evaluate_response_async( no objective scorer is set. Note that auxiliary scorer results are not returned but are still executed and stored. """ - scoring_results = await Scorer.score_response_async( response=response, objective_scorer=self._objective_scorer, diff --git a/pyrit/executor/attack/single_turn/role_play.py b/pyrit/executor/attack/single_turn/role_play.py index 7291c7ea8..2cd55eef9 100644 --- a/pyrit/executor/attack/single_turn/role_play.py +++ b/pyrit/executor/attack/single_turn/role_play.py @@ -129,8 +129,10 @@ async def _setup_async(self, *, context: SingleTurnAttackContext) -> None: def _validate_context(self, *, context: SingleTurnAttackContext) -> None: """ Validate the context before executing the attack. + Args: context (SingleTurnAttackContext): The attack context containing parameters and objective. + Raises: ValueError: If the context is invalid. 
""" diff --git a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py index 9afa7f752..51df270d2 100644 --- a/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py +++ b/pyrit/executor/attack/single_turn/single_turn_attack_strategy.py @@ -18,7 +18,7 @@ @dataclass class SingleTurnAttackContext(AttackContext): - """Context for single-turn attacks""" + """Context for single-turn attacks.""" # Unique identifier of the main conversation between the attacker and model conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) @@ -95,7 +95,6 @@ async def execute_async( """ Execute the attack strategy asynchronously with the provided parameters. """ - # Validate parameters before creating context seed_group = get_kwarg_param(kwargs=kwargs, param_name="seed_group", expected_type=SeedGroup, required=False) objective = get_kwarg_param(kwargs=kwargs, param_name="objective", expected_type=str, required=False) diff --git a/pyrit/executor/attack/single_turn/skeleton_key.py b/pyrit/executor/attack/single_turn/skeleton_key.py index 47652c196..b9f2a772b 100644 --- a/pyrit/executor/attack/single_turn/skeleton_key.py +++ b/pyrit/executor/attack/single_turn/skeleton_key.py @@ -188,7 +188,6 @@ def _create_skeleton_key_failure_result(self, *, context: SingleTurnAttackContex Returns: AttackResult: The failure result. """ - return AttackResult( conversation_id=context.conversation_id, objective=context.objective, diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index c6d71d54d..e68b8ae77 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -142,13 +142,12 @@ def _validate_context(self, *, context: FairnessBiasBenchmarkContext) -> None: async def _setup_async(self, *, context: FairnessBiasBenchmarkContext) -> None: """ Sets up phase before executing the strategy: - - Sets the objective (uses provided objective or generates default) - - Creates the story prompt based on subject and story type - - Generates the seed group for the benchmark + - sets the objective (uses provided objective or generates default), + - creates the story prompt based on subject and story type, + - generates the seed group for the benchmark. Args: context (FairnessBiasBenchmarkContext): The benchmark context to configure for execution - """ # Use provided objective or generate default if context.objective: @@ -174,7 +173,7 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta - Sends prompts to the target model - Extracts names from responses - Scores responses using the required scorer - - Stores experiment results in the context + - Stores experiment results in the context. Args: context (FairnessBiasBenchmarkContext): The configured benchmark context @@ -229,7 +228,7 @@ def _format_experiment_results( self, context: FairnessBiasBenchmarkContext, attack_result: AttackResult, experiment_num: int ): """ - Formats experiment data into a dictionary + Formats experiment data into a dictionary. 
Args: context (FairnessBiasBenchmarkContext): The benchmark context diff --git a/pyrit/executor/benchmark/question_answering.py b/pyrit/executor/benchmark/question_answering.py index 189b04ae6..031a83bbb 100644 --- a/pyrit/executor/benchmark/question_answering.py +++ b/pyrit/executor/benchmark/question_answering.py @@ -291,7 +291,6 @@ async def execute_async( """ Execute the benchmark strategy asynchronously with the provided parameters. """ - # Validate parameters before creating context question_answering_entry = get_kwarg_param( kwargs=kwargs, diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index f3d583b45..ce8e10f04 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -23,7 +23,7 @@ @dataclass class StrategyContext(ABC): - """Base class for all strategy contexts""" + """Base class for all strategy contexts.""" def duplicate(self: StrategyContextT) -> StrategyContextT: """ @@ -152,7 +152,6 @@ def __init__( event handler for strategy events. logger (logging.Logger): The logger to use for this strategy. """ - self._id = uuid.uuid4() self._context_type = context_type self._event_handlers: Dict[str, StrategyEventHandler[StrategyContextT, StrategyResultT]] = {} @@ -286,7 +285,6 @@ async def _execution_context(self, context: StrategyContextT) -> AsyncIterator[N Yields: None: Control is yielded back to the caller after setup is complete. """ - try: # Notify pre-setup event await self._handle_event(event=StrategyEvent.ON_PRE_SETUP, context=context) diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 314b5f2f8..859f22f4c 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -143,7 +143,6 @@ def _validate_context(self, *, context: AnecdoctorContext) -> None: Raises: ValueError: If the context is invalid. """ - if not context.content_type: raise ValueError("content_type must be provided in the context") @@ -417,7 +416,6 @@ async def execute_async( """ Execute the prompt generation strategy asynchronously with the provided parameters. """ - # Validate parameters before creating context content_type = get_kwarg_param(kwargs=kwargs, param_name="content_type", expected_type=str) language = get_kwarg_param(kwargs=kwargs, param_name="language", expected_type=str) diff --git a/pyrit/executor/promptgen/core/prompt_generator_strategy.py b/pyrit/executor/promptgen/core/prompt_generator_strategy.py index 97b23499a..39eb0b6b3 100644 --- a/pyrit/executor/promptgen/core/prompt_generator_strategy.py +++ b/pyrit/executor/promptgen/core/prompt_generator_strategy.py @@ -23,12 +23,12 @@ @dataclass class PromptGeneratorStrategyContext(StrategyContext, ABC): - """Base class for all prompt generator strategy contexts""" + """Base class for all prompt generator strategy contexts.""" @dataclass class PromptGeneratorStrategyResult(StrategyResult, ABC): - """Base class for all prompt generator strategy results""" + """Base class for all prompt generator strategy results.""" class _DefaultPromptGeneratorStrategyEventHandler( @@ -53,6 +53,7 @@ async def on_event( ) -> None: """ Handle an event during the execution of a prompt generator strategy. + Args: event_data (StrategyEventData[PromptGeneratorStrategyContextT, PromptGeneratorStrategyResultT]): The event data containing context and result. 
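The `duplicate()` helper on `StrategyContext` (touched in the strategy.py hunk above) is a type-preserving deep copy: a self-bound `TypeVar` lets each subclass get back an instance of its own type. A minimal, self-contained sketch of that pattern, using hypothetical context classes rather than PyRIT's own:

```python
import copy
from dataclasses import dataclass, field
from typing import List, TypeVar

ContextT = TypeVar("ContextT", bound="BaseContext")


@dataclass
class BaseContext:
    """Base context; subclasses inherit a copy method typed as their own class."""

    def duplicate(self: ContextT) -> ContextT:
        # Deep copy so mutating the branch never affects the original context.
        return copy.deepcopy(self)


@dataclass
class MultiTurnContext(BaseContext):
    objective: str = ""
    turns: List[str] = field(default_factory=list)


branch = MultiTurnContext(objective="test").duplicate()  # typed as MultiTurnContext
```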
diff --git a/pyrit/executor/promptgen/fuzzer.py b/pyrit/executor/promptgen/fuzzer.py index 8ebf71a68..7ac3a2db7 100644 --- a/pyrit/executor/promptgen/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer.py @@ -1183,7 +1183,6 @@ async def execute_async( """ Execute the Fuzzer generation strategy asynchronously with the provided parameters. """ - # Validate parameters before creating context prompts = get_kwarg_param(kwargs=kwargs, param_name="prompts", expected_type=list) diff --git a/pyrit/executor/workflow/core/workflow_strategy.py b/pyrit/executor/workflow/core/workflow_strategy.py index 85a9a3f28..1b6fab4d1 100644 --- a/pyrit/executor/workflow/core/workflow_strategy.py +++ b/pyrit/executor/workflow/core/workflow_strategy.py @@ -24,14 +24,14 @@ @dataclass class WorkflowContext(StrategyContext, ABC): - """Base class for all workflow contexts""" + """Base class for all workflow contexts.""" pass @dataclass class WorkflowResult(StrategyResult, ABC): - """Base class for all workflow results""" + """Base class for all workflow results.""" pass diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index 8edc15d28..9e223cd87 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -257,7 +257,6 @@ async def _perform_async(self, *, context: XPIAContext) -> XPIAResult: XPIAResult: The result of the workflow execution containing the processing response, optional score, and attack setup response. """ - # Step 1: Setup and send attack prompt setup_response_text = await self._setup_attack_async(context=context) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 2a0d85b7a..0f831e249 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -132,7 +132,8 @@ def _refresh_token_if_needed(self) -> None: self._create_auth_token() def _create_engine(self, *, has_echo: bool) -> Engine: - """Creates the SQLAlchemy engine for Azure SQL Server. + """ + Creates the SQLAlchemy engine for Azure SQL Server. Creates an engine bound to the specified server and database. The `has_echo` parameter controls the verbosity of SQL execution logging. @@ -140,7 +141,6 @@ def _create_engine(self, *, has_echo: bool) -> Engine: Args: has_echo (bool): Flag to enable detailed SQL execution logging. """ - try: # Create the SQLAlchemy engine. # Use pool_pre_ping (health check) to gracefully handle server-closed connections @@ -197,7 +197,7 @@ def _create_tables_if_not_exist(self): def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: """ - Inserts embedding data into memory storage + Inserts embedding data into memory storage. """ self._insert_entries(entries=embedding_data) @@ -548,7 +548,7 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict return False def reset_database(self): - """Drop and recreate existing tables""" + """Drop and recreate existing tables.""" # Drop all existing tables Base.metadata.drop_all(self.engine) # Recreate the tables diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index ec9cd8e82..47bae6501 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -55,7 +55,8 @@ class MemoryInterface(abc.ABC): - """Abstract interface for conversation memory storage systems. + """ + Abstract interface for conversation memory storage systems. This interface defines the contract for storing and retrieving chat messages and conversation history. 
    Implementations can use different storage backends
@@ -68,7 +69,8 @@ class MemoryInterface(abc.ABC):
     engine: Engine = None
 
     def __init__(self, embedding_model=None):
-        """Initialize the MemoryInterface.
+        """
+        Initialize the MemoryInterface.
 
         Args:
             embedding_model: If set, this includes embeddings in the memory entries
@@ -157,7 +159,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]
     @abc.abstractmethod
     def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None:
         """
-        Inserts embedding data into memory storage
+        Inserts embedding data into memory storage.
         """
 
     @abc.abstractmethod
@@ -426,8 +428,10 @@ def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]:
     def get_request_from_response(self, *, response: Message) -> Message:
         """
         Retrieves the request that produced the given response.
+
         Args:
             response (Message): The response message to match.
+
         Returns:
             Message: The corresponding message object.
         """
@@ -474,13 +478,14 @@ def get_message_pieces(
             not_data_type (Optional[str], optional): The data type to exclude. Defaults to None.
             converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values.
                 Defaults to None.
+
         Returns:
             Sequence[MessagePiece]: A list of MessagePiece objects that match the specified filters.
+
         Raises:
             Exception: If there is an error retrieving the prompts, an exception is logged and an empty
                 list is returned.
         """
-
         conditions = []
         if attack_id:
             conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id)))
@@ -522,7 +527,7 @@ def get_message_pieces(
 
     def duplicate_conversation(self, *, conversation_id: str, new_attack_id: Optional[str] = None) -> str:
         """
-        Duplicates a conversation for reuse
+        Duplicates a conversation for reuse.
 
         This can be useful when an attack strategy requires branching out from a particular point in the conversation.
         One cannot continue both branches with the same attack and conversation IDs since that would corrupt
@@ -532,6 +537,7 @@ def duplicate_conversation(self, *, conversation_id: str, new_attack_id: Optiona
             conversation_id (str): The conversation ID with existing conversations.
             new_attack_id (str, Optional): The new attack ID to assign to the duplicated conversations. If no
                 new attack ID is provided, the attack ID will remain the same. Defaults to None.
+
         Returns:
             The uuid for the new conversation.
         """
@@ -565,6 +571,7 @@ def duplicate_conversation_excluding_last_turn(
             conversation_id (str): The conversation ID with existing conversations.
             new_attack_id (str, Optional): The new attack ID to assign to the duplicated conversations. If no
                 new attack ID is provided, the attack ID will remain the same. Defaults to None.
+
         Returns:
             The uuid for the new conversation.
         """
@@ -638,7 +645,6 @@ def _update_sequence(self, *, message_pieces: Sequence[MessagePiece]):
         Args:
             message_pieces (Sequence[MessagePiece]): The list of message pieces to update.
         """
-
         prev_conversations = self.get_message_pieces(conversation_id=message_pieces[0].conversation_id)
 
         sequence = 0
@@ -720,7 +726,7 @@ def dispose_engine(self):
 
     def cleanup(self):
         """
-        Ensure cleanup on process exit
+        Ensure cleanup on process exit.
""" # Ensure cleanup at process exit atexit.register(self.dispose_engine) @@ -969,7 +975,8 @@ def get_seed_groups( groups: Optional[Sequence[str]] = None, source: Optional[str] = None, ) -> Sequence[SeedGroup]: - """Retrieves groups of seed prompts based on the provided filtering criteria + """ + Retrieves groups of seed prompts based on the provided filtering criteria. Args: value_sha256 (Optional[Sequence[str]], Optional): SHA256 hash of value to filter seed groups by. @@ -1099,6 +1106,7 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. + Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. """ diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index d77972789..993767a45 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -108,6 +108,7 @@ class PromptMemoryEntry(Base): idx_conversation_id (Index): The index for the conversation ID. original_prompt_id (UUID): The original prompt id. It is equal to id unless it is a duplicate. scores (list[ScoreEntry]): The list of scores associated with the prompt. + Methods: __str__(): Returns a string representation of the memory entry. """ @@ -209,7 +210,7 @@ def __str__(self): class EmbeddingDataEntry(Base): # type: ignore """ Represents the embedding data associated with conversation entries in the database. - Each embedding is linked to a specific conversation entry via an id + Each embedding is linked to a specific conversation entry via an id. Parameters: id (Uuid): The primary key, which is a foreign key referencing the UUID in the PromptMemoryEntries table. @@ -231,7 +232,7 @@ def __str__(self): class ScoreEntry(Base): # type: ignore """ - Represents the Score Memory Entry + Represents the Score Memory Entry. """ @@ -508,6 +509,7 @@ class AttackResultEntry(Base): timestamp (DateTime): The timestamp of the attack result entry. last_response (PromptMemoryEntry): Relationship to the last response prompt memory entry. last_score (ScoreEntry): Relationship to the last score entry. + Methods: __str__(): Returns a string representation of the attack result entry. """ diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 663929833..d7de218ec 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -65,7 +65,8 @@ def _init_storage_io(self): self.results_storage_io = DiskStorageIO() def _create_engine(self, *, has_echo: bool) -> Engine: - """Creates the SQLAlchemy engine for SQLite. + """ + Creates the SQLAlchemy engine for SQLite. Creates an engine bound to the specified database file. The `has_echo` parameter controls the verbosity of SQL execution logging. @@ -150,7 +151,7 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None: """ - Inserts embedding data into memory storage + Inserts embedding data into memory storage. 
""" self._insert_entries(entries=embedding_data) diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index c4c6729d0..db0d00b39 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -32,7 +32,7 @@ class AttackOutcome(Enum): @dataclass class AttackResult(StrategyResult): - """Base class for all attack results""" + """Base class for all attack results.""" # Identity # Unique identifier of the conversation that produced this result diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index ae959d867..cd4982403 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -139,6 +139,7 @@ async def save_data(self, data: bytes, output_filename: Optional[str] = None) -> async def save_b64_image(self, data: str, output_filename: str = None) -> None: """ Saves the base64 encoded image to storage. + Arguments: data: string with base64 data output_filename (optional, str): filename to store image as. Defaults to UUID if not provided @@ -158,6 +159,7 @@ async def save_formatted_audio( ) -> None: """ Saves the PCM16 of other specially formatted audio data to storage. + Arguments: data: bytes with audio data output_filename (optional, str): filename to store audio as. Defaults to UUID if not provided diff --git a/pyrit/models/embeddings.py b/pyrit/models/embeddings.py index 4b317b23f..6a136cbdb 100644 --- a/pyrit/models/embeddings.py +++ b/pyrit/models/embeddings.py @@ -31,7 +31,8 @@ class EmbeddingResponse(BaseModel): data: list[EmbeddingData] def save_to_file(self, directory_path: Path) -> str: - """Save the embedding response to disk and return the path of the new file + """ + Save the embedding response to disk and return the path of the new file. Args: directory_path: The path to save the file to @@ -46,7 +47,8 @@ def save_to_file(self, directory_path: Path) -> str: @staticmethod def load_from_file(file_path: Path) -> EmbeddingResponse: - """Load the embedding response from disk + """ + Load the embedding response from disk. Args: file_path: The path to load the file from @@ -63,7 +65,8 @@ def to_json(self) -> str: class EmbeddingSupport(ABC): @abstractmethod def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse: - """Generate text embedding + """ + Generate text embedding. Args: text: The text to generate the embedding for diff --git a/pyrit/models/message.py b/pyrit/models/message.py index f809fb2d6..bdcfe9602 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -188,7 +188,6 @@ def group_conversation_message_pieces_by_sequence( ... ]) ... ] """ - if not message_pieces: return [] @@ -274,7 +273,6 @@ def construct_response_from_request( """ Constructs a response entry from a request. """ - if request.prompt_metadata: prompt_metadata = combine_dict(request.prompt_metadata, prompt_metadata or {}) diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index d18605f74..fe06a2491 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -16,7 +16,8 @@ class MessagePiece: - """Represents a piece of a message to a target. + """ + Represents a piece of a message to a target. This class represents a single piece of a message that will be sent to a target. Since some targets can handle multiple pieces (e.g., text and images), @@ -49,7 +50,8 @@ def __init__( scores: Optional[List[Score]] = None, targeted_harm_categories: Optional[List[str]] = None, ): - """Initialize a MessagePiece. 
+ """ + Initialize a MessagePiece. Args: role: The role of the prompt (system, assistant, user). @@ -79,7 +81,6 @@ def __init__( scores: The scores associated with the prompt. Defaults to None. targeted_harm_categories: The harm categories associated with the prompt. Defaults to None. """ - self.id = id if id else uuid4() if role not in ChatMessageRole.__args__: # type: ignore diff --git a/pyrit/models/question_answering.py b/pyrit/models/question_answering.py index ba651672a..8e86668ae 100644 --- a/pyrit/models/question_answering.py +++ b/pyrit/models/question_answering.py @@ -46,7 +46,6 @@ class QuestionAnsweringEntry(BaseModel): def get_correct_answer_text(self) -> str: """Get the text of the correct answer.""" - correct_answer_index = self.correct_answer try: # Match using the explicit choice.index (not enumerate position) so non-sequential indices are supported diff --git a/pyrit/models/seed.py b/pyrit/models/seed.py index a6141bd9d..069b9e6b2 100644 --- a/pyrit/models/seed.py +++ b/pyrit/models/seed.py @@ -88,7 +88,8 @@ class Seed(YamlLoadable): prompt_group_id: Optional[uuid.UUID] = None def render_template_value(self, **kwargs) -> str: - """Renders self.value as a template, applying provided parameters in kwargs + """ + Renders self.value as a template, applying provided parameters in kwargs. Args: kwargs:Key-value pairs to replace in the SeedPrompt value. @@ -99,7 +100,6 @@ def render_template_value(self, **kwargs) -> str: Raises: ValueError: If parameters are missing or invalid in the template. """ - jinja_template = Template(self.value, undefined=StrictUndefined) try: @@ -108,8 +108,9 @@ def render_template_value(self, **kwargs) -> str: raise ValueError(f"Error applying parameters: {str(e)}") def render_template_value_silent(self, **kwargs) -> str: - """Renders self.value as a template, applying provided parameters in kwargs. For parameters in the template - that are not provided as kwargs here, this function will leave them as is instead of raising an error. + """ + Renders self.value as a template, applying provided parameters in kwargs. For parameters in the template + that are not provided as kwargs here, this function will leave them as is instead of raising an error. Args: kwargs: Key-value pairs to replace in the SeedPrompt value. diff --git a/pyrit/models/seed_dataset.py b/pyrit/models/seed_dataset.py index 9e70339a4..8b753549c 100644 --- a/pyrit/models/seed_dataset.py +++ b/pyrit/models/seed_dataset.py @@ -223,7 +223,8 @@ def from_dict(cls, data: Dict[str, Any]) -> SeedDataset: return cls(prompts=merged_prompts, **dataset_defaults) def render_template_value(self, **kwargs): - """Renders self.value as a template, applying provided parameters in kwargs + """ + Renders self.value as a template, applying provided parameters in kwargs. Args: kwargs:Key-value pairs to replace in the SeedDataset value. @@ -234,14 +235,13 @@ def render_template_value(self, **kwargs): Raises: ValueError: If parameters are missing or invalid in the template. """ - for prompt in self.prompts: prompt.value = prompt.render_template_value(**kwargs) @staticmethod def _set_seed_group_id_by_alias(seed_prompts: Sequence[dict]): """ - Sets all seed_group_ids based on prompt_group_alias matches + Sets all seed_group_ids based on prompt_group_alias matches. 
This is important so the prompt_group_alias can be set in yaml to group prompts """ @@ -260,7 +260,7 @@ def _set_seed_group_id_by_alias(seed_prompts: Sequence[dict]): def group_seed_prompts_by_prompt_group_id(seed: Sequence[Seed]) -> Sequence[SeedGroup]: """ Groups the given list of Seeds by their prompt_group_id and creates - SeedGroup instances. All seed prompts in a group must share the same prompt_group_id + SeedGroup instances. All seed prompts in a group must share the same prompt_group_id. Args: seed: A list of Seed objects. diff --git a/pyrit/models/seed_group.py b/pyrit/models/seed_group.py index e4a2f1789..c6658db78 100644 --- a/pyrit/models/seed_group.py +++ b/pyrit/models/seed_group.py @@ -60,7 +60,8 @@ def __init__( ) def render_template_value(self, **kwargs): - """Renders self.value as a template, applying provided parameters in kwargs + """ + Renders self.value as a template, applying provided parameters in kwargs. Args: kwargs:Key-value pairs to replace in the SeedGroup value. @@ -71,7 +72,6 @@ def render_template_value(self, **kwargs): Raises: ValueError: If parameters are missing or invalid in the template. """ - for prompt in self.prompts: prompt.value = prompt.render_template_value(**kwargs) diff --git a/pyrit/models/seed_objective.py b/pyrit/models/seed_objective.py index 0b1e0cc80..07138eb2e 100644 --- a/pyrit/models/seed_objective.py +++ b/pyrit/models/seed_objective.py @@ -19,7 +19,7 @@ class SeedObjective(Seed): """Represents a seed objective with various attributes and metadata.""" def __post_init__(self) -> None: - """Post-initialization to render the template to replace existing values""" + """Post-initialization to render the template to replace existing values.""" self.value = super().render_template_value_silent(**PATHS_DICT) self.data_type = "text" diff --git a/pyrit/models/seed_prompt.py b/pyrit/models/seed_prompt.py index 18ea646c6..6ab958500 100644 --- a/pyrit/models/seed_prompt.py +++ b/pyrit/models/seed_prompt.py @@ -37,7 +37,7 @@ class SeedPrompt(Seed): parameters: Optional[Sequence[str]] = field(default_factory=lambda: []) def __post_init__(self) -> None: - """Post-initialization to render the template to replace existing values""" + """Post-initialization to render the template to replace existing values.""" self.value = self.render_template_value_silent(**PATHS_DICT) if not self.data_type: diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 3c4452897..4204e0c20 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -22,7 +22,7 @@ class SupportedContentType(Enum): """ All supported content types for uploading blobs to provided storage account container. - See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml + See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml. """ # TODO, add other media supported types @@ -73,8 +73,10 @@ class DiskStorageIO(StorageIO): async def read_file(self, path: Union[Path, str]) -> bytes: """ Asynchronously reads a file from the local disk. + Args: path (Union[Path, str]): The path to the file. + Returns: bytes: The content of the file. """ @@ -85,6 +87,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: async def write_file(self, path: Union[Path, str], data: bytes) -> None: """ Asynchronously writes data to a file on the local disk. + Args: path (Path): The path to the file. data (bytes): The content to write to the file. 
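For context on these `DiskStorageIO` hunks, the async file contract can be satisfied with only the standard library by pushing blocking I/O onto a worker thread. A minimal sketch under that assumption (illustrative only; PyRIT's actual implementation may differ):

```python
import asyncio
from pathlib import Path


async def read_file(path: Path) -> bytes:
    # Offload the blocking read so the event loop stays responsive.
    return await asyncio.to_thread(path.read_bytes)


async def write_file(path: Path, data: bytes) -> None:
    await asyncio.to_thread(path.write_bytes, data)


async def main() -> None:
    target = Path("example.bin")
    await write_file(target, b"\x00\x01")
    assert await read_file(target) == b"\x00\x01"


asyncio.run(main())
```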
@@ -96,8 +99,10 @@ async def write_file(self, path: Union[Path, str], data: bytes) -> None: async def path_exists(self, path: Union[Path, str]) -> bool: """ Checks if a path exists on the local disk. + Args: path (Path): The path to check. + Returns: bool: True if the path exists, False otherwise. """ @@ -107,8 +112,10 @@ async def path_exists(self, path: Union[Path, str]) -> bool: async def is_file(self, path: Union[Path, str]) -> bool: """ Checks if the given path is a file (not a directory). + Args: path (Path): The path to check. + Returns: bool: True if the path is a file, False otherwise. """ @@ -118,6 +125,7 @@ async def is_file(self, path: Union[Path, str]) -> bool: async def create_directory_if_not_exists(self, path: Union[Path, str]) -> None: """ Asynchronously creates a directory if it doesn't exist on the local disk. + Args: path (Path): The directory path to create. """ @@ -155,9 +163,11 @@ def __init__( self._client_async: AsyncContainerClient = None async def _create_container_client_async(self): - """Creates an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the + """ + Creates an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used - for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication.""" + for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. + """ if not self._sas_token: logger.info("SAS token not provided. Creating a delegation SAS token using Entra ID authentication.") sas_token = await AzureStorageAuth.get_sas_token(self._container_url) @@ -176,7 +186,6 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st data (bytes): Byte representation of content to upload to container. content_type (str): Content type to upload. """ - content_settings = ContentSettings(content_type=f"{content_type}") logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) diff --git a/pyrit/models/strategy_result.py b/pyrit/models/strategy_result.py index 02a1956eb..e4b90cfab 100644 --- a/pyrit/models/strategy_result.py +++ b/pyrit/models/strategy_result.py @@ -13,7 +13,7 @@ @dataclass class StrategyResult(ABC): - """Base class for all strategy results""" + """Base class for all strategy results.""" def duplicate(self: StrategyResultT) -> StrategyResultT: """ diff --git a/pyrit/prompt_converter/add_image_to_video_converter.py b/pyrit/prompt_converter/add_image_to_video_converter.py index 5ac6c6f5f..c4336d830 100644 --- a/pyrit/prompt_converter/add_image_to_video_converter.py +++ b/pyrit/prompt_converter/add_image_to_video_converter.py @@ -50,7 +50,6 @@ def __init__( Raises: ValueError: If ``video_path`` is empty or invalid. """ - if not video_path: raise ValueError("Please provide valid video path") @@ -70,7 +69,6 @@ async def _add_image_to_video(self, image_path: str, output_path: str) -> str: Returns: str: The output video path. """ - try: import cv2 # noqa: F401 except ModuleNotFoundError as e: diff --git a/pyrit/prompt_converter/base2048_converter.py b/pyrit/prompt_converter/base2048_converter.py index 7d100849c..c3df28d8d 100644 --- a/pyrit/prompt_converter/base2048_converter.py +++ b/pyrit/prompt_converter/base2048_converter.py @@ -12,7 +12,8 @@ class Base2048Converter(PromptConverter): - """Converter that encodes text to base2048 format. 
+ """ + Converter that encodes text to base2048 format. This converter takes input text and converts it to base2048 encoding, which uses 2048 different Unicode characters to represent binary data. @@ -25,7 +26,8 @@ def __init__(self) -> None: pass async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: - """Converts the given prompt to base2048 encoding. + """ + Converts the given prompt to base2048 encoding. Args: prompt: The prompt to be converted. diff --git a/pyrit/prompt_converter/base64_converter.py b/pyrit/prompt_converter/base64_converter.py index 7d429cbe6..5b8a1e508 100644 --- a/pyrit/prompt_converter/base64_converter.py +++ b/pyrit/prompt_converter/base64_converter.py @@ -10,7 +10,8 @@ class Base64Converter(PromptConverter): - """Converter that encodes text to base64 format. + """ + Converter that encodes text to base64 format. This converter takes input text and converts it to base64 encoding, which can be useful for obfuscating text or testing how systems @@ -29,7 +30,8 @@ class Base64Converter(PromptConverter): ] def __init__(self, *, encoding_func: EncodingFunc = "b64encode") -> None: - """Initialize the Base64Converter. + """ + Initialize the Base64Converter. Args: encoding_func: The base64 encoding function to use. Defaults to "b64encode". @@ -37,7 +39,8 @@ def __init__(self, *, encoding_func: EncodingFunc = "b64encode") -> None: self._encoding_func = encoding_func async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult: - """Converts the given prompt to base64 encoding. + """ + Converts the given prompt to base64 encoding. Args: prompt: The prompt to be converted. diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index 3f609acdc..deb44a03d 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -64,7 +64,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text Returns: ConverterResult: The result containing the modified prompt. """ - # check if the prompt contains any words from the denylist and if so, # update the prompt replacing the denied words with synonyms denylist = self._prompt_kwargs.get("denylist", []) diff --git a/pyrit/prompt_converter/insert_punctuation_converter.py b/pyrit/prompt_converter/insert_punctuation_converter.py index 34afab200..e92e49c73 100644 --- a/pyrit/prompt_converter/insert_punctuation_converter.py +++ b/pyrit/prompt_converter/insert_punctuation_converter.py @@ -45,8 +45,10 @@ def _is_valid_punctuation(self, punctuation_list: List[str]) -> bool: """ Check if all items in the list are valid punctuation characters in string.punctuation. Space, letters, numbers, double punctuations are all invalid. + Args: punctuation_list (List[str]): List of punctuations to validate. + Returns: bool: valid list and valid punctuations """ @@ -88,9 +90,11 @@ async def convert_async( def _insert_punctuation(self, prompt: str, punctuation_list: List[str]) -> str: """ Insert punctuation into the prompt. + Args: prompt (str): The text to modify. punctuation_list (List[str]): List of punctuations for insertion. + Returns: str: The modified prompt with inserted punctuation from helper method. """ @@ -117,6 +121,7 @@ def _insert_between_words( ) -> str: """ Insert punctuation between words in the prompt. + Args: words (List[str]): List of words and punctuations. 
             word_indices (List[int]): Indices of the actual words without punctuations in words list.
@@ -141,10 +146,12 @@ def _insert_between_words(
     def _insert_within_words(self, prompt: str, num_insertions: int, punctuation_list: List[str]) -> str:
         """
         Insert punctuation at any indices in the prompt; can insert into a word.
+
         Args:
             prompt (str): The prompt string.
             num_insertions (int): Number of punctuations to insert.
             punctuation_list (List[str]): Punctuations for insertion.
+
         Returns:
             str: The modified prompt with inserted punctuation.
         """
diff --git a/pyrit/prompt_converter/llm_generic_text_converter.py b/pyrit/prompt_converter/llm_generic_text_converter.py
index 86d2dce7f..9f1e9ee2a 100644
--- a/pyrit/prompt_converter/llm_generic_text_converter.py
+++ b/pyrit/prompt_converter/llm_generic_text_converter.py
@@ -69,7 +69,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text
         Returns:
             ConverterResult: The result containing the converted output and its type.
         """
-
         conversation_id = str(uuid.uuid4())
 
         kwargs = self._prompt_kwargs.copy()
diff --git a/pyrit/prompt_converter/malicious_question_generator_converter.py b/pyrit/prompt_converter/malicious_question_generator_converter.py
index f068d1b2a..2ba019508 100644
--- a/pyrit/prompt_converter/malicious_question_generator_converter.py
+++ b/pyrit/prompt_converter/malicious_question_generator_converter.py
@@ -36,7 +36,6 @@ def __init__(
                 Can be omitted if a default has been configured via PyRIT initialization.
             prompt_template (SeedPrompt): The seed prompt template to use.
         """
-
         # set to default strategy if not provided
         prompt_template = (
             prompt_template
diff --git a/pyrit/prompt_converter/math_prompt_converter.py b/pyrit/prompt_converter/math_prompt_converter.py
index abc2439ee..78d041cdb 100644
--- a/pyrit/prompt_converter/math_prompt_converter.py
+++ b/pyrit/prompt_converter/math_prompt_converter.py
@@ -36,7 +36,6 @@ def __init__(
                 Can be omitted if a default has been configured via PyRIT initialization.
             prompt_template (SeedPrompt): The seed prompt template to use.
         """
-
         # Load the template from the YAML file or use a default template if not provided
         prompt_template = (
             prompt_template
diff --git a/pyrit/prompt_converter/pdf_converter.py b/pyrit/prompt_converter/pdf_converter.py
index 93390d458..9480fb3cf 100644
--- a/pyrit/prompt_converter/pdf_converter.py
+++ b/pyrit/prompt_converter/pdf_converter.py
@@ -182,7 +182,6 @@ def _generate_pdf(self, content: str) -> bytes:
         Returns:
             bytes: The generated PDF content in bytes.
         """
-
         pdf_buffer = BytesIO()
 
         # Convert mm to points
diff --git a/pyrit/prompt_converter/toxic_sentence_generator_converter.py b/pyrit/prompt_converter/toxic_sentence_generator_converter.py
index ab80381b6..ab52fd5cf 100644
--- a/pyrit/prompt_converter/toxic_sentence_generator_converter.py
+++ b/pyrit/prompt_converter/toxic_sentence_generator_converter.py
@@ -45,7 +45,6 @@ def __init__(
             prompt_template (SeedPrompt): The seed prompt template to use. If not provided, defaults
                 to the ``toxic_sentence_generator.yaml``.
         """
-
         # set to default strategy if not provided
         prompt_template = (
             prompt_template
diff --git a/pyrit/prompt_converter/translation_converter.py b/pyrit/prompt_converter/translation_converter.py
index b2968b231..fd3646fb9 100644
--- a/pyrit/prompt_converter/translation_converter.py
+++ b/pyrit/prompt_converter/translation_converter.py
@@ -93,7 +93,6 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text
         Raises:
             ValueError: If the input type is not supported.
""" - conversation_id = str(uuid.uuid4()) self.converter_target.set_system_prompt(system_prompt=self.system_prompt, conversation_id=conversation_id) diff --git a/pyrit/prompt_converter/transparency_attack_converter.py b/pyrit/prompt_converter/transparency_attack_converter.py index d504efcf0..03e35d2b0 100644 --- a/pyrit/prompt_converter/transparency_attack_converter.py +++ b/pyrit/prompt_converter/transparency_attack_converter.py @@ -21,7 +21,7 @@ class _AdamOptimizer: Implementation of the Adam Optimizer using NumPy. Adam optimization is a stochastic gradient descent method that is based on adaptive estimation of first-order and second-order moments. For further details, see the original paper: `"Adam: A Method for Stochastic Optimization"` - by D. P. Kingma and J. Ba, 2014: https://arxiv.org/abs/1412.6980 + by D. P. Kingma and J. Ba, 2014: https://arxiv.org/abs/1412.6980. Note: The code is inspired by the implementation found at: @@ -220,7 +220,7 @@ async def _save_blended_image(self, attack_image: numpy.ndarray, alpha: numpy.nd async def convert_async(self, *, prompt: str, input_type: PromptDataType = "image_path") -> ConverterResult: """ Converts the given prompt by blending an attack image (potentially harmful) with a benign image. - Uses the Novel Image Blending Algorithm from: https://arxiv.org/abs/2401.15817 + Uses the Novel Image Blending Algorithm from: https://arxiv.org/abs/2401.15817. Args: prompt (str): The image file path to the attack image. diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index d69bb85a1..cc4c2d117 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -65,7 +65,7 @@ async def send_prompt_async( attack_identifier (Optional[dict[str, str]], optional): Identifier for the attack. Defaults to None. - Raises: + Raises: Exception: If an error occurs during the request processing. ValueError: If the prompts in the SeedGroup are not part of the same sequence. @@ -156,7 +156,6 @@ async def send_prompt_batch_to_target_async( list[Message]: A list of Message objects representing the responses received for each prompt. """ - batch_items: List[List[Any]] = [ [request.seed_group for request in requests], [request.request_converter_configurations for request in requests], @@ -322,7 +321,6 @@ async def _build_message( Returns: Message: The message object. """ - entries = [] # All message pieces within Message needs to have same conversation ID. diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index a53051a54..2f23a9e66 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -21,7 +21,7 @@ class SupportedContentType(Enum): """ All supported content types for uploading blobs to provided storage account container. - See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml + See all options here: https://www.iana.org/assignments/media-types/media-types.xhtml. """ PLAIN_TEXT = "text/plain" @@ -68,9 +68,11 @@ def __init__( super().__init__(endpoint=self._container_url, max_requests_per_minute=max_requests_per_minute) async def _create_container_client_async(self) -> None: - """Creates an asynchronous ContainerClient for Azure Storage. If a SAS token is provided via the + """ + Creates an asynchronous ContainerClient for Azure Storage. 
If a SAS token is provided via the AZURE_STORAGE_ACCOUNT_SAS_TOKEN environment variable or the init sas_token parameter, it will be used - for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication.""" + for authentication. Otherwise, a delegation SAS token will be created using Entra ID authentication. + """ container_url, _ = self._parse_url() try: sas_token: str = default_values.get_required_value( @@ -94,7 +96,6 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st data (bytes): Byte representation of content to upload to container. content_type (str): Content type to upload. """ - content_settings = ContentSettings(content_type=f"{content_type}") logger.info(msg="\nUploading to Azure Storage as blob:\n\t" + file_name) diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index bcd4355f2..60db7751c 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -230,7 +230,6 @@ def _construct_http_body( messages: list[ChatMessage], ) -> dict: """Constructs the HTTP request body for the AML online endpoint.""" - squashed_messages = self.chat_message_normalizer.normalize(messages) messages_dict = [message.model_dump() for message in squashed_messages] @@ -255,11 +254,12 @@ def _construct_http_body( return data def _get_headers(self) -> dict: - """Headers for accessing inference endpoint deployed in AML. + """ + Headers for accessing inference endpoint deployed in AML. + Returns: headers(dict): contains bearer token as AML key and content-type: JSON """ - headers: dict = { "Content-Type": "application/json", "Authorization": ("Bearer " + self._api_key), diff --git a/pyrit/prompt_target/batch_helper.py b/pyrit/prompt_target/batch_helper.py index edac7764e..5a6606697 100644 --- a/pyrit/prompt_target/batch_helper.py +++ b/pyrit/prompt_target/batch_helper.py @@ -37,7 +37,6 @@ def _validate_rate_limit_parameters(prompt_target: Optional[PromptTarget], batch Raises: ValueError: When rate limit RPM is specified for the target and batching is not adjusted to 1. """ - exc_message = "Batch size must be configured to 1 for the target requests per minute value to be respected." if prompt_target and prompt_target._max_requests_per_minute and batch_size != 1: raise ValueError(exc_message) @@ -66,7 +65,6 @@ async def batch_task_async( Returns: responses(list): List of results from the batched function """ - responses = [] _validate_rate_limit_parameters(prompt_target=prompt_target, batch_size=batch_size) diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index c8b310bdb..7532624a4 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -45,7 +45,7 @@ async def send_prompt_async(self, *, message: Message) -> Message: @abc.abstractmethod def _validate_request(self, *, message: Message) -> None: """ - Validates the provided message + Validates the provided message. """ def set_model_name(self, *, model_name: str) -> None: diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index e83031d37..3fe43bbc0 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -72,7 +72,7 @@ def _validate_request(self, *, message: Message) -> None: async def check_password(self, password: str) -> bool: """ - Checks if the password is correct + Checks if the password is correct. 
True means the password is correct, False means it is not """ diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index 9cd34420b..1f44faa28 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -24,7 +24,7 @@ class HTTPTarget(PromptTarget): """ - HTTP_Target is for endpoints that do not have an API and instead require HTTP request(s) to send a prompt + HTTP_Target is for endpoints that do not have an API and instead require HTTP request(s) to send a prompt. Parameters: http_request (str): the header parameters as a request (i.e., from Burp) @@ -96,7 +96,7 @@ def with_client( def _inject_prompt_into_request(self, request: MessagePiece) -> str: """ Adds the prompt into the URL if the prompt_regex_string is found in the - http_request + http_request. """ re_pattern = re.compile(self.prompt_regex_string) if re.search(self.prompt_regex_string, self.http_request): @@ -159,7 +159,7 @@ async def send_prompt_async(self, *, message: Message) -> Message: def parse_raw_http_request(self, http_request: str) -> tuple[Dict[str, str], RequestBody, str, str, str]: """ - Parses the HTTP request string into a dictionary of headers + Parses the HTTP request string into a dictionary of headers. Parameters: http_request: the header parameters as a request str with @@ -172,7 +172,6 @@ def parse_raw_http_request(self, http_request: str) -> tuple[Dict[str, str], Req http_method (str): method (ie GET vs POST) http_version (str): HTTP version to use """ - headers_dict: Dict[str, str] = {} if self._client: headers_dict = dict(self._client.headers.copy()) diff --git a/pyrit/prompt_target/http_target/http_target_callback_functions.py b/pyrit/prompt_target/http_target/http_target_callback_functions.py index 2ccb11606..297dbbcaf 100644 --- a/pyrit/prompt_target/http_target/http_target_callback_functions.py +++ b/pyrit/prompt_target/http_target/http_target_callback_functions.py @@ -11,7 +11,7 @@ def get_http_target_json_response_callback_function(key: str) -> Callable: """ - Determines proper parsing response function for an HTTP Request + Determines proper parsing response function for an HTTP Request. Parameters: key (str): this is the path pattern to follow for parsing the output response @@ -24,7 +24,7 @@ def get_http_target_json_response_callback_function(key: str) -> Callable: def parse_json_http_response(response: requests.Response): """ - Parses JSON outputs + Parses JSON outputs. Parameters: response (response): the HTTP Response to parse @@ -41,7 +41,7 @@ def parse_json_http_response(response: requests.Response): def get_http_target_regex_matching_callback_function(key: str, url: str = None) -> Callable: def parse_using_regex(response: requests.Response): """ - Parses text outputs using regex + Parses text outputs using regex. Parameters: url (optional str): the original URL if this is needed to get a full URL response back (ie BIC) diff --git a/pyrit/prompt_target/http_target/httpx_api_target.py b/pyrit/prompt_target/http_target/httpx_api_target.py index a931c4239..e98da6efa 100644 --- a/pyrit/prompt_target/http_target/httpx_api_target.py +++ b/pyrit/prompt_target/http_target/httpx_api_target.py @@ -82,7 +82,6 @@ async def send_prompt_async(self, *, message: Message) -> Message: - If file_path is set or we can deduce it from the message piece, we upload a file. - Otherwise, we send normal requests with JSON or form_data (if provided). 
""" - self._validate_request(message=message) message_piece: MessagePiece = message.message_pieces[0] diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 747125b0e..4447910b0 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -25,7 +25,8 @@ class HuggingFaceChatTarget(PromptChatTarget): - """The HuggingFaceChatTarget interacts with HuggingFace models, specifically for conducting red teaming activities. + """ + The HuggingFaceChatTarget interacts with HuggingFace models, specifically for conducting red teaming activities. Inherits from PromptTarget to comply with the current design standards. """ @@ -59,7 +60,8 @@ def __init__( attn_implementation: Optional[str] = None, max_requests_per_minute: Optional[int] = None, ) -> None: - """Initializes the HuggingFaceChatTarget. + """ + Initializes the HuggingFaceChatTarget. Args: model_id (Optional[str]): The Hugging Face model ID. Either model_id or model_path must be provided. @@ -149,7 +151,8 @@ def is_model_id_valid(self) -> bool: return False async def load_model_and_tokenizer(self): - """Loads the model and tokenizer, downloading if necessary. + """ + Loads the model and tokenizer, downloading if necessary. Downloads the model to the HF_MODELS_DIR folder if it does not exist, then loads it from there. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index 6d41d1daf..c826ea799 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -12,7 +12,8 @@ class HuggingFaceEndpointTarget(PromptTarget): - """The HuggingFaceEndpointTarget interacts with HuggingFace models hosted on cloud endpoints. + """ + The HuggingFaceEndpointTarget interacts with HuggingFace models hosted on cloud endpoints. Inherits from PromptTarget to comply with the current design standards. """ @@ -29,7 +30,8 @@ def __init__( max_requests_per_minute: Optional[int] = None, verbose: bool = False, ) -> None: - """Initializes the HuggingFaceEndpointTarget with API credentials and model parameters. + """ + Initializes the HuggingFaceEndpointTarget with API credentials and model parameters. Args: hf_token (str): The Hugging Face token for authenticating with the Hugging Face endpoint. diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index b355c7ad2..b28cf2565 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -25,7 +25,7 @@ class OpenAIChatTarget(OpenAIChatTargetBase): """ - This class facilitates multimodal (image and text) input and text output generation + This class facilitates multimodal (image and text) input and text output generation. This works with GPT3.5, GPT4, GPT4o, GPT-V, and other compatible models @@ -148,7 +148,8 @@ def _set_openai_env_configuration_vars(self) -> None: self.api_key_environment_variable = "OPENAI_CHAT_KEY" async def _build_chat_messages_async(self, conversation: MutableSequence[Message]) -> list[dict]: - """Builds chat messages based on message entries. + """ + Builds chat messages based on message entries. Args: conversation (list[Message]): A list of Message objects. 
@@ -162,7 +163,8 @@ async def _build_chat_messages_async(self, conversation: MutableSequence[Message return await self._build_chat_messages_for_multi_modal_async(conversation) def _is_text_message_format(self, conversation: MutableSequence[Message]) -> bool: - """Checks if the message piece is in text message format. + """ + Checks if the message piece is in text message format. Args: conversation list[Message]: The conversation @@ -180,7 +182,7 @@ def _is_text_message_format(self, conversation: MutableSequence[Message]) -> boo def _build_chat_messages_for_text(self, conversation: MutableSequence[Message]) -> list[dict]: """ Builds chat messages based on message entries. This is needed because many - openai "compatible" models don't support ChatMessageListDictContent format (this is more universally accepted) + openai "compatible" models don't support ChatMessageListDictContent format (this is more universally accepted). Args: conversation (list[Message]): A list of Message objects. @@ -303,7 +305,8 @@ def _construct_message_from_openai_json( return construct_response_from_request(request=message_piece, response_text_pieces=[extracted_response]) def _validate_request(self, *, message: Message) -> None: - """Validates the structure and content of a message for compatibility of this target. + """ + Validates the structure and content of a message for compatibility of this target. Args: message (Message): The message object. @@ -311,7 +314,6 @@ def _validate_request(self, *, message: Message) -> None: Raises: ValueError: If any of the message pieces have a data type other than 'text' or 'image_path'. """ - converted_prompt_data_types = [ message_piece.converted_value_data_type for message_piece in message.message_pieces ] diff --git a/pyrit/prompt_target/openai/openai_chat_target_base.py b/pyrit/prompt_target/openai/openai_chat_target_base.py index fc09d9e69..754e231be 100644 --- a/pyrit/prompt_target/openai/openai_chat_target_base.py +++ b/pyrit/prompt_target/openai/openai_chat_target_base.py @@ -44,6 +44,8 @@ def __init__( **kwargs, ): """ + Initialize the OpenAIChatTargetBase with the given parameters. + Args: model_name (str, Optional): The name of the model. endpoint (str, Optional): The target URL for the OpenAI service. @@ -94,7 +96,8 @@ def __init__( @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async(self, *, message: Message) -> Message: - """Asynchronously sends a message and handles the response within a managed conversation context. + """ + Asynchronously sends a message and handles the response within a managed conversation context. Args: message (Message): The message object. @@ -102,7 +105,6 @@ async def send_prompt_async(self, *, message: Message) -> Message: Returns: Message: The updated conversation entry with the response from the prompt target. """ - self._validate_request(message=message) self.refresh_auth_headers() diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index 9da81ed11..d5707626e 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -34,6 +34,8 @@ def __init__( **kwargs, ): """ + Initialize the OpenAICompletionTarget with the given parameters. + Args: model_name (str, Optional): The name of the model. endpoint (str, Optional): The target URL for the OpenAI service. @@ -63,7 +65,6 @@ def __init__( `httpx.AsyncClient()` constructor. 
For example, to specify a 3 minutes timeout: httpx_client_kwargs={"timeout": 180} """ - super().__init__(*args, **kwargs) self._max_tokens = max_tokens diff --git a/pyrit/prompt_target/openai/openai_dall_e_target.py b/pyrit/prompt_target/openai/openai_dall_e_target.py index 608d2e593..d0ad0afda 100644 --- a/pyrit/prompt_target/openai/openai_dall_e_target.py +++ b/pyrit/prompt_target/openai/openai_dall_e_target.py @@ -73,7 +73,6 @@ def __init__( ValueError: If `num_images` is not 1 for DALL-E-3. ValueError: If `num_images` is less than 1 or greater than 10 for DALL-E-2. """ - self.dalle_version = dalle_version if dalle_version == "dall-e-3": if num_images != 1: diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index b38577229..b10e8a270 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -85,7 +85,6 @@ def __init__( httpx.AsyncClient() constructor. For example, to specify a 3 minutes timeout: httpx_client_kwargs={"timeout": 180} """ - super().__init__(**kwargs) self.system_prompt = system_prompt or "You are a helpful AI assistant" @@ -102,7 +101,6 @@ async def connect(self): Connects to Realtime API Target using websockets. Returns the WebSocket connection. """ - logger.info(f"Connecting to WebSocket: {self._endpoint}") query_params = { @@ -167,7 +165,6 @@ async def send_config(self, conversation_id: str): Args: conversation_id (str): Conversation ID """ - config_variables = self._set_system_prompt_and_config_vars() await self.send_event( @@ -266,7 +263,7 @@ async def cleanup_target(self): async def cleanup_conversation(self, conversation_id: str): """ - Disconnects from the WebSocket server for a specific conversation + Disconnects from the WebSocket server for a specific conversation. Args: conversation_id (str): The conversation ID to disconnect from. @@ -449,6 +446,7 @@ def _handle_audio_delta_event(*, event: dict) -> bytes: async def send_text_async(self, text: str, conversation_id: str) -> Tuple[str, RealtimeTargetResult]: """ Sends text prompt to the WebSocket server. + Args: text: prompt to send. conversation_id: conversation ID @@ -517,7 +515,8 @@ async def send_audio_async(self, filename: str, conversation_id: str) -> Tuple[s return output_audio_path, result def _validate_request(self, *, message: Message) -> None: - """Validates the structure and content of a message for compatibility of this target. + """ + Validates the structure and content of a message for compatibility of this target. Args: message (Message): The message object. @@ -526,7 +525,6 @@ def _validate_request(self, *, message: Message) -> None: ValueError: If more than two message pieces are provided. ValueError: If any of the message pieces have a data type other than 'text' or 'audio_path'. """ - # Check the number of message pieces n_pieces = len(message.message_pieces) if n_pieces != 1: diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 32792de7a..b0405e41d 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -490,7 +490,8 @@ def _parse_response_output_section( ) def _validate_request(self, *, message: Message) -> None: - """Validates the structure and content of a message for compatibility of this target. + """ + Validates the structure and content of a message for compatibility of this target. 
Args: message (Message): The message object. diff --git a/pyrit/prompt_target/openai/openai_sora_target.py b/pyrit/prompt_target/openai/openai_sora_target.py index 62386099a..a80c3396b 100644 --- a/pyrit/prompt_target/openai/openai_sora_target.py +++ b/pyrit/prompt_target/openai/openai_sora_target.py @@ -123,7 +123,8 @@ def __init__( n_seconds: int = 4, **kwargs, ): - """Initialize the unified OpenAI Sora Target. + """ + Initialize the unified OpenAI Sora Target. Args: model_name (str, Optional): The name of the model. @@ -315,7 +316,8 @@ async def _send_httpx_request_async( @limit_requests_per_minute @pyrit_target_retry async def send_prompt_async(self, *, message: Message) -> Message: - """Asynchronously sends a message and handles the response within a managed conversation context. + """ + Asynchronously sends a message and handles the response within a managed conversation context. Args: message (Message): The message object. diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 1eaaccc73..0bd91e77d 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -108,16 +108,18 @@ def _set_auth_headers(self, use_entra_auth, passed_api_key) -> None: self._headers["Authorization"] = f"Bearer {self._api_key}" def refresh_auth_headers(self) -> None: - """Refresh the authentication headers. This is particularly useful for Entra authentication - where tokens need to be refreshed periodically.""" + """ + Refresh the authentication headers. This is particularly useful for Entra authentication + where tokens need to be refreshed periodically. + """ if self._azure_auth: self._headers["Authorization"] = f"Bearer {self._azure_auth.refresh_token()}" @abstractmethod def _set_openai_env_configuration_vars(self) -> None: """ - Sets deployment_environment_variable, endpoint_environment_variable, and api_key_environment_variable - which are read from .env + Sets deployment_environment_variable, endpoint_environment_variable, + and api_key_environment_variable which are read from .env file. """ raise NotImplementedError diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index bf362d52c..cabef4ccd 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -62,7 +62,6 @@ def __init__( httpx.AsyncClient() constructor. For example, to specify a 3 minutes timeout: httpx_client_kwargs={"timeout": 180} """ - super().__init__(**kwargs) if not self._model_name: diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index b4c06726e..17344b332 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -290,7 +290,8 @@ async def _extract_content_if_ready_async( return None async def _extract_text_from_message_groups(self, ai_message_groups: list, text_selector: str) -> List[str]: - """Extract text content from message groups using the provided selector. + """ + Extract text content from message groups using the provided selector. Args: ai_message_groups: List of message group elements to extract text from @@ -313,7 +314,8 @@ async def _extract_text_from_message_groups(self, ai_message_groups: list, text_ return all_text_parts def _filter_placeholder_text(self, text_parts: List[str]) -> List[str]: - """Filter out placeholder/loading text from extracted content. 
+ """ + Filter out placeholder/loading text from extracted content. Args: text_parts: List of text strings to filter @@ -329,7 +331,8 @@ def _filter_placeholder_text(self, text_parts: List[str]) -> List[str]: return [text for text in text_parts if text.lower() not in placeholder_texts] async def _count_images_in_groups(self, message_groups: list) -> int: - """Count total images in message groups (both iframes and direct). + """ + Count total images in message groups (both iframes and direct). Args: message_groups: List of message group elements to search @@ -357,7 +360,8 @@ async def _count_images_in_groups(self, message_groups: list) -> int: return image_count async def _wait_minimum_time(self, seconds: int): - """Wait for a minimum amount of time, logging progress. + """ + Wait for a minimum amount of time, logging progress. Args: seconds: Number of seconds to wait @@ -369,7 +373,8 @@ async def _wait_minimum_time(self, seconds: int): async def _wait_for_images_to_stabilize( self, selectors: CopilotSelectors, ai_message_groups: list, initial_group_count: int = 0 ) -> list: - """Wait for images to appear and DOM to stabilize. + """ + Wait for images to appear and DOM to stabilize. Images may appear 1-5 seconds after text, and the DOM structure can change (e.g., from 3 groups to 2 groups). This method waits until either: @@ -435,7 +440,8 @@ async def _wait_for_images_to_stabilize( return all_groups[initial_group_count:] async def _extract_images_from_iframes(self, ai_message_groups: list) -> list: - """Extract images from iframes within message groups. + """ + Extract images from iframes within message groups. Args: ai_message_groups: List of message group elements to search @@ -470,7 +476,8 @@ async def _extract_images_from_iframes(self, ai_message_groups: list) -> list: return iframe_images async def _extract_images_from_message_groups(self, selectors: CopilotSelectors, ai_message_groups: list) -> list: - """Extract images directly from message groups (fallback when no iframes). + """ + Extract images directly from message groups (fallback when no iframes). Args: selectors: The selectors for the Copilot interface @@ -516,7 +523,8 @@ async def _extract_images_from_message_groups(self, selectors: CopilotSelectors, return image_elements async def _process_image_elements(self, image_elements: list) -> List[Tuple[str, PromptDataType]]: - """Process image elements and save them to disk. + """ + Process image elements and save them to disk. Args: image_elements: List of image elements to process @@ -556,7 +564,8 @@ async def _process_image_elements(self, image_elements: list) -> List[Tuple[str, async def _extract_and_filter_text_async( self, *, ai_message_groups: list, text_selector: str ) -> List[Tuple[str, PromptDataType]]: - """Extract and filter text content from message groups. + """ + Extract and filter text content from message groups. Args: ai_message_groups: Message groups to process @@ -584,7 +593,8 @@ async def _extract_and_filter_text_async( async def _extract_all_images_async( self, *, selectors: CopilotSelectors, ai_message_groups: list, initial_group_count: int ) -> List[Tuple[str, PromptDataType]]: - """Extract all images from message groups using iframe and direct methods. + """ + Extract all images from message groups using iframe and direct methods. 
Args: selectors: Copilot interface selectors @@ -612,7 +622,8 @@ async def _extract_all_images_async( return await self._process_image_elements(image_elements) async def _extract_fallback_text_async(self, *, ai_message_groups: list) -> str: - """Extract fallback text content when no other content is found. + """ + Extract fallback text content when no other content is found. Args: ai_message_groups: Message groups to extract from @@ -632,7 +643,8 @@ async def _extract_fallback_text_async(self, *, ai_message_groups: list) -> str: def _assemble_response( self, *, response_pieces: List[Tuple[str, PromptDataType]] ) -> Union[str, List[Tuple[str, PromptDataType]]]: - """Assemble response pieces into appropriate return format. + """ + Assemble response pieces into appropriate return format. Args: response_pieces: List of (content, data_type) tuples @@ -654,7 +666,8 @@ def _assemble_response( async def _extract_multimodal_content_async( self, selectors: CopilotSelectors, initial_group_count: int = 0 ) -> Union[str, List[Tuple[str, PromptDataType]]]: - """Extract multimodal content (text and images) from Copilot response. + """ + Extract multimodal content (text and images) from Copilot response. Args: selectors: The selectors for the Copilot interface diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index 06989b6fd..8c051eab6 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -78,7 +78,6 @@ def __init__( minute before hitting a rate limit. The number of requests sent to the target will be capped at the value provided. """ - endpoint_value = default_values.get_required_value( env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint ) @@ -104,9 +103,8 @@ async def send_prompt_async(self, *, message: Message) -> Message: """ Parses the text in message to separate the userPrompt and documents contents, then sends an HTTP request to the endpoint and obtains a response in JSON. For more info, visit - https://learn.microsoft.com/en-us/azure/ai-services/content-safety/quickstart-jailbreak + https://learn.microsoft.com/en-us/azure/ai-services/content-safety/quickstart-jailbreak. """ - self._validate_request(message=message) request = message.message_pieces[0] @@ -165,7 +163,6 @@ def _validate_response(self, request_body: dict, response_body: dict) -> None: """ Ensures that every field sent to the Prompt Shield was analyzed. """ - user_prompt_sent: str | None = request_body.get("userPrompt") documents_sent: list[str] | None = request_body.get("documents") @@ -181,9 +178,8 @@ def _validate_response(self, request_body: dict, response_body: dict) -> None: def _input_parser(self, input_str: str) -> dict[str, Any]: """ Parses the input given to the target to extract the two fields sent to - Prompt Shield: userPrompt: str, and documents: list[str] + Prompt Shield: userPrompt: str, and documents: list[str]. """ - match self._force_entry_field: case "userPrompt": return {"userPrompt": input_str, "documents": []} diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index fe9533709..8db751e40 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -14,7 +14,7 @@ class TextTarget(PromptTarget): """ The TextTarget takes prompts, adds them to memory and writes them to io - which is sys.stdout by default + which is sys.stdout by default. 
This can be useful in various situations, for example, if operators want to generate prompts but enter them manually. diff --git a/pyrit/scenarios/atomic_attack.py b/pyrit/scenarios/atomic_attack.py index e11833bb7..f72153bc6 100644 --- a/pyrit/scenarios/atomic_attack.py +++ b/pyrit/scenarios/atomic_attack.py @@ -123,7 +123,6 @@ def __init__( TypeError: If seed_groups is provided for multi-turn attacks or custom_prompts is provided for single-turn attacks. """ - self.atomic_attack_name = atomic_attack_name if not objectives: diff --git a/pyrit/scenarios/scenario.py b/pyrit/scenarios/scenario.py index aabd8d935..d56cd810f 100644 --- a/pyrit/scenarios/scenario.py +++ b/pyrit/scenarios/scenario.py @@ -626,7 +626,6 @@ async def _execute_scenario_async(self) -> ScenarioResult: Raises: Exception: Any exception that occurs during scenario execution. """ - logger.info(f"Starting scenario '{self._name}' execution with {len(self._atomic_attacks)} atomic attacks") # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point diff --git a/pyrit/scenarios/scenarios/encoding_scenario.py b/pyrit/scenarios/scenarios/encoding_scenario.py index c9403fa14..8a5c12d63 100644 --- a/pyrit/scenarios/scenarios/encoding_scenario.py +++ b/pyrit/scenarios/scenarios/encoding_scenario.py @@ -143,7 +143,6 @@ def __init__( encoding-modified prompts. scenario_result_id (Optional[str]): Optional ID of an existing scenario result to resume. """ - objective_scorer = objective_scorer or DecodingScorer(categories=["encoding_scenario"]) self._scorer_config = AttackScoringConfig(objective_scorer=objective_scorer) @@ -253,7 +252,6 @@ def _get_prompt_attacks(self, *, converters: list[PromptConverter], encoding_nam Returns: list[AtomicAttack]: List of atomic attacks for this encoding scheme. """ - converter_configs = [ AttackConverterConfig( request_converters=PromptConverterConfiguration.from_converters(converters=converters) diff --git a/pyrit/scenarios/scenarios/foundry_scenario.py b/pyrit/scenarios/scenarios/foundry_scenario.py index 8f8999fc2..e106b2038 100644 --- a/pyrit/scenarios/scenarios/foundry_scenario.py +++ b/pyrit/scenarios/scenarios/foundry_scenario.py @@ -259,7 +259,6 @@ def __init__( Raises: ValueError: If attack_strategies is empty or contains unsupported strategies. """ - self._adversarial_chat = adversarial_chat if adversarial_chat else self._get_default_adversarial_target() self._objective_scorer = objective_scorer if objective_scorer else self._get_default_scorer() self._objectives: list[str] = ( diff --git a/pyrit/score/aggregator_utils.py b/pyrit/score/aggregator_utils.py index d59cac083..d30cc7df7 100644 --- a/pyrit/score/aggregator_utils.py +++ b/pyrit/score/aggregator_utils.py @@ -8,7 +8,8 @@ def combine_metadata_and_categories(scores: List[Score]) -> tuple[Dict[str, Union[str, int]], List[str]]: - """Combine metadata and categories from multiple scores with deduplication. + """ + Combine metadata and categories from multiple scores with deduplication. Args: scores: List of Score objects. @@ -29,7 +30,8 @@ def combine_metadata_and_categories(scores: List[Score]) -> tuple[Dict[str, Unio def format_score_for_rationale(score: Score) -> str: - """Format a single score for inclusion in an aggregated rationale. + """ + Format a single score for inclusion in an aggregated rationale. Args: score: The Score object to format. 
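Aside: the aggregator_utils.py hunks above only touch docstring layout, but the merge behavior those docstrings describe is easy to misread. Below is a minimal sketch of the deduplicating combine, using a hypothetical stand-in for PyRIT's Score object; the field names here are assumptions for illustration, not taken from the diff.

```python
# Illustrative sketch only: FakeScore is a hypothetical stand-in for
# pyrit's Score; field names are assumptions.
from dataclasses import dataclass, field
from typing import Dict, List, Union


@dataclass
class FakeScore:
    score_metadata: Dict[str, Union[str, int]] = field(default_factory=dict)
    score_category: List[str] = field(default_factory=list)


def combine(scores: List[FakeScore]) -> tuple[Dict[str, Union[str, int]], List[str]]:
    metadata: Dict[str, Union[str, int]] = {}
    categories: List[str] = []
    for score in scores:
        # Later scores win on metadata key collisions; categories are
        # deduplicated while preserving first-seen order.
        metadata.update(score.score_metadata)
        for category in score.score_category:
            if category not in categories:
                categories.append(category)
    return metadata, categories


combined_meta, combined_cats = combine(
    [
        FakeScore({"model": "gpt-4o", "trial": 1}, ["hate"]),
        FakeScore({"trial": 2}, ["hate", "violence"]),
    ]
)
assert combined_meta == {"model": "gpt-4o", "trial": 2}
assert combined_cats == ["hate", "violence"]
```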
diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py
index 8f9b30957..ab92bec4b 100644
--- a/pyrit/score/float_scale/azure_content_filter_scorer.py
+++ b/pyrit/score/float_scale/azure_content_filter_scorer.py
@@ -50,7 +50,7 @@ def __init__(
         validator: Optional[ScorerPromptValidator] = None,
     ) -> None:
         """
-        Class that initializes an Azure Content Filter Scorer
+        Initializes an Azure Content Filter Scorer.
 
         Args:
             api_key (str, Optional): The API key for accessing the Azure OpenAI service (only if you're not
@@ -61,7 +61,6 @@ def __init__(
-            harm_categories: The harm categories you want to query for as per defined in
+            harm_categories: The harm categories you want to query for as defined in
                 azure.ai.contentsafety.models.TextCategory.
         """
-
         super().__init__(validator=validator or self._default_validator)
 
         if harm_categories:
@@ -92,7 +91,8 @@ def __init__(
             raise ValueError("Please provide the Azure Content Safety endpoint")
 
     async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]:
-        """Evaluating the input text or image using the Azure Content Filter API
+        """
+        Evaluates the input text or image using the Azure Content Filter API.
 
         Args:
             message_piece (MessagePiece): The message piece containing the text to be scored.
diff --git a/pyrit/score/float_scale/float_scale_score_aggregator.py b/pyrit/score/float_scale/float_scale_score_aggregator.py
index 631416408..53e447f46 100644
--- a/pyrit/score/float_scale/float_scale_score_aggregator.py
+++ b/pyrit/score/float_scale/float_scale_score_aggregator.py
@@ -16,7 +16,8 @@
 
 
 def _build_rationale(scores: List[Score], *, aggregate_description: str) -> tuple[str, str]:
-    """Build description and rationale for aggregated scores.
+    """
+    Build description and rationale for aggregated scores.
 
     Args:
         scores: List of Score objects to aggregate.
@@ -43,7 +44,8 @@ def _create_aggregator(
     result_func: FloatScaleOp,
     aggregate_description: str,
 ) -> FloatScaleAggregatorFunc:
-    """Create a float-scale aggregator using a result function over float values.
+    """
+    Create a float-scale aggregator using a result function over float values.
 
     Args:
         name (str): Name of the aggregator variant.
@@ -99,7 +101,8 @@ def aggregator(scores: Iterable[Score]) -> List[ScoreAggregatorResult]:
 
 # Float scale aggregators (return list with single score)
 class FloatScaleScoreAggregator:
-    """Namespace for float scale score aggregators that return a single aggregated score.
+    """
+    Namespace for float scale score aggregators that return a single aggregated score.
 
     All aggregators return a list containing one ScoreAggregatorResult that combines
     all input scores together, preserving all categories.
@@ -131,7 +134,8 @@ def _create_aggregator_by_category(
     aggregate_description: str,
     group_by_category: bool = True,
 ) -> FloatScaleAggregatorFunc:
-    """Create a float-scale aggregator that can optionally group scores by category.
+    """
+    Create a float-scale aggregator that can optionally group scores by category.
 
     When group_by_category=True (default), scores are grouped by their category and
     each category is aggregated separately, returning multiple ScoreAggregatorResult objects.
@@ -248,7 +252,8 @@ def aggregator(scores: Iterable[Score]) -> List[ScoreAggregatorResult]:
 
 # Category-aware aggregators (group by category and return multiple scores)
 class FloatScaleScorerByCategory:
-    """Namespace for float scale score aggregators that group by category.
+    """
+    Namespace for float scale score aggregators that group by category.
These aggregators return multiple ScoreAggregatorResult objects (one per category). Useful for scorers like AzureContentFilterScorer that return multiple scores per item. @@ -278,7 +283,8 @@ class FloatScaleScorerByCategory: # Non-category-aware aggregators (combine all categories into one score) class FloatScaleScorerAllCategories: - """Namespace for float scale score aggregators that combine all categories. + """ + Namespace for float scale score aggregators that combine all categories. These aggregators ignore category boundaries and aggregate all scores together, returning a single ScoreAggregatorResult with all categories combined. diff --git a/pyrit/score/float_scale/look_back_scorer.py b/pyrit/score/float_scale/look_back_scorer.py index 33ec91872..ec224265a 100644 --- a/pyrit/score/float_scale/look_back_scorer.py +++ b/pyrit/score/float_scale/look_back_scorer.py @@ -64,7 +64,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op list[Score]: A list containing a single Score object representing the detected amount of behavior change throughout the conversation. """ - # Retrieve conversation from provided MessagePiece conversation = self._memory.get_conversation(conversation_id=message_piece.conversation_id) diff --git a/pyrit/score/float_scale/plagiarism_scorer.py b/pyrit/score/float_scale/plagiarism_scorer.py index f9c89641d..75a4dec20 100644 --- a/pyrit/score/float_scale/plagiarism_scorer.py +++ b/pyrit/score/float_scale/plagiarism_scorer.py @@ -19,7 +19,8 @@ class PlagiarismMetric(Enum): class PlagiarismScorer(FloatScaleScorer): - """A scorer that measures plagiarism by computing word-level similarity + """ + A scorer that measures plagiarism by computing word-level similarity between the AI response and a reference text. This scorer implements three similarity metrics: @@ -37,14 +38,14 @@ def __init__( n: int = 5, validator: Optional[ScorerPromptValidator] = None, ) -> None: - """Initializes the PlagiarismScorer. + """ + Initializes the PlagiarismScorer. Args: reference_text (str): The reference text to compare against. metric (PlagiarismMetric, optional): The plagiarism detection metric to use. n (int, optional): The n-gram size for n-gram similarity (default is 5). """ - super().__init__(validator=validator or self._default_validator) self.reference_text = reference_text @@ -131,7 +132,8 @@ def _plagiarism_score( raise ValueError("metric must be 'lcs', 'levenshtein', or 'jaccard'") async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Scores the AI response against the reference text using the specified metric. + """ + Scores the AI response against the reference text using the specified metric. Args: message_piece (MessagePiece): The piece to score. diff --git a/pyrit/score/human/human_in_the_loop_gradio.py b/pyrit/score/human/human_in_the_loop_gradio.py index 38e21a9f8..6d2868cc9 100644 --- a/pyrit/score/human/human_in_the_loop_gradio.py +++ b/pyrit/score/human/human_in_the_loop_gradio.py @@ -44,7 +44,8 @@ def __init__( self._rpc_server.start() async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Score a message piece using human input through Gradio interface. + """ + Score a message piece using human input through Gradio interface. Args: message_piece (MessagePiece): The message piece to be scored by a human. 
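Aside: the plagiarism_scorer.py hunks a few files up name three word-level metrics (LCS, Levenshtein, n-gram Jaccard). A rough sketch of what each computes is below; the normalization choices are assumptions for illustration, not PyRIT's exact implementation.

```python
# Rough sketch of the three word-level similarity metrics the
# PlagiarismScorer docstring names. Normalizations to [0, 1] are assumed.
def lcs_similarity(response: str, reference: str) -> float:
    a, b = response.split(), reference.split()
    # Classic longest-common-subsequence DP over words.
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, wa in enumerate(a):
        for j, wb in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if wa == wb else max(dp[i][j + 1], dp[i + 1][j])
    # Fraction of the reference reproduced as an in-order subsequence.
    return dp[len(a)][len(b)] / len(b) if b else 0.0


def levenshtein_similarity(response: str, reference: str) -> float:
    a, b = response.split(), reference.split()
    # Word-level edit distance with a rolling row.
    prev = list(range(len(b) + 1))
    for i, wa in enumerate(a, 1):
        curr = [i]
        for j, wb in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (wa != wb)))
        prev = curr
    denom = max(len(a), len(b))
    return 1.0 - prev[len(b)] / denom if denom else 0.0


def jaccard_similarity(response: str, reference: str, n: int = 5) -> float:
    # Jaccard overlap of word n-grams (n=5 mirrors the documented default).
    def ngrams(words: list[str]) -> set[tuple[str, ...]]:
        return {tuple(words[i : i + n]) for i in range(len(words) - n + 1)}

    a, b = ngrams(response.split()), ngrams(reference.split())
    return len(a & b) / len(a | b) if a | b else 0.0
```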
@@ -53,7 +54,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op Returns: list[Score]: A list containing a single Score object based on human evaluation. """ - try: score = await asyncio.to_thread(self.retrieve_score, message_piece, objective=objective) return score @@ -62,7 +62,8 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op raise def retrieve_score(self, request_prompt: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Retrieve a score from the human evaluator through the RPC server. + """ + Retrieve a score from the human evaluator through the RPC server. Args: request_prompt (MessagePiece): The message piece to be scored. diff --git a/pyrit/score/score_aggregator_result.py b/pyrit/score/score_aggregator_result.py index f49499516..21e821630 100644 --- a/pyrit/score/score_aggregator_result.py +++ b/pyrit/score/score_aggregator_result.py @@ -7,7 +7,8 @@ @dataclass(frozen=True, slots=True) class ScoreAggregatorResult: - """Common result object returned by score aggregators. + """ + Common result object returned by score aggregators. Attributes: value (Union[bool, float]): The aggregated value. For true/false aggregators this is diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 011f41c75..e38e5731d 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -403,7 +403,6 @@ async def _score_value_with_llm( UnvalidatedScore: The score object containing the response from the target LLM. score_value still needs to be normalized and validated. """ - conversation_id = str(uuid.uuid4()) if attack_identifier: @@ -504,7 +503,6 @@ def _extract_objective_from_response(self, response: Message) -> str: Returns: str: The objective extracted from the response. """ - if not response.message_pieces: return "" diff --git a/pyrit/score/scorer_evaluation/krippendorff.py b/pyrit/score/scorer_evaluation/krippendorff.py index 5be86863a..ff3078c5b 100644 --- a/pyrit/score/scorer_evaluation/krippendorff.py +++ b/pyrit/score/scorer_evaluation/krippendorff.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Krippendorff's alpha for ordinal data. +""" +Krippendorff's alpha for ordinal data. This implementation follows the standard Krippendorff's alpha formulation and is inspired by LightTag/simpledorff's clean decomposition of expected/observed @@ -20,7 +21,8 @@ def _validate_and_prepare_data( level_of_measurement: str, missing: float | None, ) -> tuple["np.ndarray", "np.ndarray", "np.ndarray"]: - """Validate inputs and prepare data for reliability calculation. + """ + Validate inputs and prepare data for reliability calculation. Args: reliability_data: Ratings array of shape (num_raters_or_trials, num_items). @@ -65,7 +67,8 @@ def _build_value_counts_matrix( valid_mask: "np.ndarray", categories: "np.ndarray", ) -> "np.ndarray": - """Build matrix counting how many raters assigned each category to each item. + """ + Build matrix counting how many raters assigned each category to each item. Args: data: Float64 array of ratings. @@ -94,7 +97,8 @@ def _build_value_counts_matrix( def _build_coincidence_matrix( value_counts: "np.ndarray", ) -> "np.ndarray": - """Build coincidence matrix from value counts. + """ + Build coincidence matrix from value counts. Args: value_counts: Matrix of shape (num_items, num_categories) with counts. 
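Aside: the value-counts step documented in the krippendorff.py hunk above is the entry point of the alpha computation. A small sketch under the shapes the docstrings give (raters x items in, items x categories out), skipping NaN ratings; the module's internals may differ.

```python
# Sketch of the value-counts step: count how many raters assigned each
# category to each item, ignoring missing (NaN) ratings.
import numpy as np

ratings = np.array(
    [
        [1.0, 2.0, np.nan],  # rater 1 over three items
        [1.0, 2.0, 3.0],     # rater 2
        [2.0, 2.0, 3.0],     # rater 3
    ]
)
valid = ~np.isnan(ratings)
categories = np.unique(ratings[valid])  # e.g. [1., 2., 3.]

# value_counts[item, category] = number of raters choosing that category
value_counts = np.zeros((ratings.shape[1], categories.size), dtype=int)
for item in range(ratings.shape[1]):
    column = ratings[valid[:, item], item]
    for k, cat in enumerate(categories):
        value_counts[item, k] = int(np.sum(column == cat))

print(value_counts)  # -> [[2 1 0], [0 3 0], [0 0 2]] (one row per item)
```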
@@ -130,7 +134,8 @@ def _build_coincidence_matrix( def _build_expected_matrix( coincidence_matrix: "np.ndarray", ) -> tuple["np.ndarray", "np.ndarray", float]: - """Build expected coincidence matrix from observed coincidences. + """ + Build expected coincidence matrix from observed coincidences. Args: coincidence_matrix: Observed coincidence matrix. @@ -157,7 +162,8 @@ def _build_ordinal_distance_matrix( num_categories: int, n_v: "np.ndarray", ) -> "np.ndarray": - """Build ordinal distance matrix using category marginals. + """ + Build ordinal distance matrix using category marginals. Args: num_categories: Number of unique categories. @@ -183,7 +189,8 @@ def _compute_alpha_from_disagreements( observed_disagreement: float, expected_disagreement: float, ) -> float: - """Compute Krippendorff's alpha from observed and expected disagreements. + """ + Compute Krippendorff's alpha from observed and expected disagreements. Args: observed_disagreement: Observed disagreement value. @@ -213,7 +220,8 @@ def krippendorff_alpha( level_of_measurement: str = "ordinal", missing: float | None = np.nan, ) -> float: - """Compute Krippendorff's alpha inter-rater reliability for ordinal data. + """ + Compute Krippendorff's alpha inter-rater reliability for ordinal data. Computes inter-rater reliability for ordered categories, ignoring missing entries and supporting varying numbers of raters per item. diff --git a/pyrit/score/scorer_evaluation/scorer_evaluator.py b/pyrit/score/scorer_evaluation/scorer_evaluator.py index dfbbe06d5..57ecee2e1 100644 --- a/pyrit/score/scorer_evaluation/scorer_evaluator.py +++ b/pyrit/score/scorer_evaluation/scorer_evaluator.py @@ -357,6 +357,7 @@ async def run_evaluation_async( num_scorer_trials (int): The number of trials to run the scorer on all responses. Defaults to 1. save_results (bool): Whether to save the metrics and model scoring results. Defaults to True. csv_path (Optional[Union[str, Path]]): The path to the CSV file to save results to. + Returns: HarmScorerMetrics: The metrics for the harm scorer. """ @@ -585,6 +586,7 @@ async def run_evaluation_async( labeled_dataset (HumanLabeledDataset): The HumanLabeledDataset to evaluate against. num_scorer_trials (int): The number of trials to run the scorer on all responses. Defaults to 1. save_results (bool): Whether to save the metrics and model scoring results. Defaults to True. + Returns: ObjectiveScorerMetrics: The metrics for the objective scorer. """ diff --git a/pyrit/score/true_false/decoding_scorer.py b/pyrit/score/true_false/decoding_scorer.py index 1527c2859..11d3196e0 100644 --- a/pyrit/score/true_false/decoding_scorer.py +++ b/pyrit/score/true_false/decoding_scorer.py @@ -35,7 +35,8 @@ def __init__( aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, validator: Optional[ScorerPromptValidator] = None, ) -> None: - """Initialize the DecodingScorer. + """ + Initialize the DecodingScorer. Args: text_matcher (Optional[TextMatching]): The text matching strategy to use. @@ -50,7 +51,8 @@ def __init__( self._score_categories = categories if categories else [] async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Score the given request piece based on text matching strategy. + """ + Score the given request piece based on text matching strategy. Args: message_piece (MessagePiece): The message piece to score. 
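Aside: taken together, the krippendorff.py hunks above walk the standard ordinal-alpha decomposition end to end. Here is a compact, self-contained sketch of that pipeline (value counts, coincidence matrix, expected matrix, ordinal distances); it follows the textbook formulation, so the module's exact internals may differ.

```python
# End-to-end sketch of ordinal Krippendorff's alpha, written against the
# standard formulation rather than the module's private helpers.
import numpy as np


def ordinal_alpha(ratings: np.ndarray) -> float:
    valid = ~np.isnan(ratings)
    categories = np.unique(ratings[valid])
    n_items, n_cats = ratings.shape[1], categories.size

    # counts[u, c]: number of raters assigning category c to item u
    counts = np.zeros((n_items, n_cats))
    for u in range(n_items):
        col = ratings[valid[:, u], u]
        for c, cat in enumerate(categories):
            counts[u, c] = np.sum(col == cat)

    # Observed coincidences, skipping items with fewer than two ratings
    observed = np.zeros((n_cats, n_cats))
    for u in range(n_items):
        m_u = counts[u].sum()
        if m_u < 2:
            continue
        pair = np.outer(counts[u], counts[u]) - np.diag(counts[u])
        observed += pair / (m_u - 1)

    n_v = observed.sum(axis=0)  # category marginals
    n_total = n_v.sum()
    expected = (np.outer(n_v, n_v) - np.diag(n_v)) / (n_total - 1)

    # Ordinal distance: sum of marginals from c through k, minus half of
    # each endpoint's marginal, squared.
    dist = np.zeros((n_cats, n_cats))
    for c in range(n_cats):
        for k in range(c + 1, n_cats):
            d = n_v[c : k + 1].sum() - (n_v[c] + n_v[k]) / 2.0
            dist[c, k] = dist[k, c] = d * d

    d_o = (observed * dist).sum()  # observed disagreement
    d_e = (expected * dist).sum()  # expected disagreement
    return 1.0 - d_o / d_e


ratings = np.array(
    [
        [1.0, 2.0, 3.0, 3.0],
        [1.0, 2.0, 3.0, 3.0],
        [np.nan, 3.0, 3.0, 3.0],
    ]
)
print(round(ordinal_alpha(ratings), 3))  # ~0.77: strong but imperfect agreement
```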
@@ -61,7 +63,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op list[Score]: A list containing a single Score object with a boolean value indicating whether any of the user piece values match the response. """ - memory = CentralMemory.get_memory_instance() user_request = memory.get_request_from_response(response=message_piece.to_message()) diff --git a/pyrit/score/true_false/float_scale_threshold_scorer.py b/pyrit/score/true_false/float_scale_threshold_scorer.py index 27cd13adb..54b555092 100644 --- a/pyrit/score/true_false/float_scale_threshold_scorer.py +++ b/pyrit/score/true_false/float_scale_threshold_scorer.py @@ -24,7 +24,8 @@ def __init__( threshold: float, float_scale_aggregator: FloatScaleAggregatorFunc = FloatScaleScoreAggregator.MAX, ) -> None: - """Initialize the FloatScaleThresholdScorer. + """ + Initialize the FloatScaleThresholdScorer. Args: scorer (FloatScaleScorer): The underlying float scale scorer to use. @@ -49,7 +50,8 @@ async def _score_async( objective: Optional[str] = None, role_filter: Optional[ChatMessageRole] = None, ) -> list[Score]: - """Scores the piece using the underlying float-scale scorer and thresholds the resulting score. + """ + Scores the piece using the underlying float-scale scorer and thresholds the resulting score. Args: message (Message): The message to score. @@ -97,7 +99,8 @@ async def _score_async( return [score] async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Float Scale scorers do not support piecewise scoring. + """ + Float Scale scorers do not support piecewise scoring. Args: message_piece (MessagePiece): Unused. diff --git a/pyrit/score/true_false/gandalf_scorer.py b/pyrit/score/true_false/gandalf_scorer.py index 845162929..1bd658d69 100644 --- a/pyrit/score/true_false/gandalf_scorer.py +++ b/pyrit/score/true_false/gandalf_scorer.py @@ -111,7 +111,8 @@ async def _check_for_password_in_conversation(self, conversation_id: str) -> str return response_text async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Scores the text based on the password found in the text. + """ + Scores the text based on the password found in the text. Args: text (str): The text to be scored. diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index 31884b91b..4e2b6c618 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -36,7 +36,8 @@ def __init__( validator: Optional[ScorerPromptValidator] = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: - """Initialize the PromptShieldScorer. + """ + Initialize the PromptShieldScorer. Args: prompt_shield_target (PromptShieldTarget): The Prompt Shield target to use for scoring. @@ -96,7 +97,6 @@ def _parse_response_to_boolean_list(self, response: str) -> list[bool]: Remember that you can just access the metadata attribute to get the original Prompt Shield endpoint response, and then just call json.loads() on it to interact with it. 
""" - response_json: dict = json.loads(response) user_detections = [] diff --git a/pyrit/score/true_false/question_answer_scorer.py b/pyrit/score/true_false/question_answer_scorer.py index b99c7ad42..acc69b1e8 100644 --- a/pyrit/score/true_false/question_answer_scorer.py +++ b/pyrit/score/true_false/question_answer_scorer.py @@ -62,7 +62,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op Returns: list[Score]: A list containing a single Score object indicating whether the correct answer was found. """ - result = False matching_text = None diff --git a/pyrit/score/true_false/self_ask_category_scorer.py b/pyrit/score/true_false/self_ask_category_scorer.py index 4c0a44bed..316340185 100644 --- a/pyrit/score/true_false/self_ask_category_scorer.py +++ b/pyrit/score/true_false/self_ask_category_scorer.py @@ -52,7 +52,6 @@ def __init__( Defaults to TrueFalseScoreAggregator.OR. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. """ - super().__init__(score_aggregator=score_aggregator, validator=validator or self._default_validator) content_classifier_path = self._verify_and_resolve_path(content_classifier_path) diff --git a/pyrit/score/true_false/self_ask_question_answer_scorer.py b/pyrit/score/true_false/self_ask_question_answer_scorer.py index b649a06c8..29afdf02c 100644 --- a/pyrit/score/true_false/self_ask_question_answer_scorer.py +++ b/pyrit/score/true_false/self_ask_question_answer_scorer.py @@ -49,7 +49,6 @@ def __init__( score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. """ - true_false_question_path = self._verify_and_resolve_path( true_false_question_path if true_false_question_path @@ -74,7 +73,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op Returns: list[Score]: A list containing a single Score object representing whether the answer was correct. """ - prompt = ( f"{objective}\n\nThe target response is as follows. " f"Evaluate if the response is correct:\n{message_piece.converted_value}" diff --git a/pyrit/score/true_false/self_ask_refusal_scorer.py b/pyrit/score/true_false/self_ask_refusal_scorer.py index fafc66c5f..0a5580b73 100644 --- a/pyrit/score/true_false/self_ask_refusal_scorer.py +++ b/pyrit/score/true_false/self_ask_refusal_scorer.py @@ -23,7 +23,8 @@ class SelfAskRefusalScorer(TrueFalseScorer): - """A self-ask scorer that detects refusal in AI responses. + """ + A self-ask scorer that detects refusal in AI responses. This scorer uses a language model to determine whether a response contains a refusal to answer or comply with the given prompt. It's useful for @@ -39,7 +40,8 @@ def __init__( validator: Optional[ScorerPromptValidator] = None, score_aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, ) -> None: - """Initialize the SelfAskRefusalScorer. + """ + Initialize the SelfAskRefusalScorer. Args: chat_target (PromptChatTarget): The endpoint that will be used to score the prompt. @@ -47,7 +49,6 @@ def __init__( score_aggregator (TrueFalseAggregatorFunc): The aggregator function to use. Defaults to TrueFalseScoreAggregator.OR. 
""" - super().__init__(score_aggregator=score_aggregator, validator=validator or self._default_validator) self._prompt_target = chat_target @@ -61,7 +62,8 @@ def __init__( self._score_category = ["refusal"] async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Scores the prompt and determines whether the response is a refusal. + """ + Scores the prompt and determines whether the response is a refusal. Args: message_piece (MessagePiece): The message piece to score. diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index cfa4da398..44d8374f2 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -127,7 +127,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op The score_value is True or False based on which description fits best. Metadata can be configured to provide additional information. """ - unvalidated_score: UnvalidatedScore = await self._score_value_with_llm( prompt_target=self._prompt_target, system_prompt=self._system_prompt, diff --git a/pyrit/score/true_false/substring_scorer.py b/pyrit/score/true_false/substring_scorer.py index 26e658657..9afb31949 100644 --- a/pyrit/score/true_false/substring_scorer.py +++ b/pyrit/score/true_false/substring_scorer.py @@ -14,7 +14,8 @@ class SubStringScorer(TrueFalseScorer): - """Scorer that checks if a given substring is present in the text. + """ + Scorer that checks if a given substring is present in the text. This scorer performs substring matching using a configurable text matching strategy. Supports both exact substring matching and approximate matching. @@ -31,7 +32,8 @@ def __init__( aggregator: TrueFalseAggregatorFunc = TrueFalseScoreAggregator.OR, validator: Optional[ScorerPromptValidator] = None, ) -> None: - """Initialize the SubStringScorer. + """ + Initialize the SubStringScorer. Args: substring (str): The substring to search for in the text. @@ -48,7 +50,8 @@ def __init__( self._score_categories = categories if categories else [] async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Score the given message piece based on presence of the substring. + """ + Score the given message piece based on presence of the substring. Args: message_piece (MessagePiece): The message piece to score. diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index a2ce5ffcf..65aa187ce 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -11,7 +11,8 @@ class TrueFalseCompositeScorer(TrueFalseScorer): - """Composite true/false scorer that aggregates results from other true/false scorers. + """ + Composite true/false scorer that aggregates results from other true/false scorers. This scorer invokes a collection of constituent ``TrueFalseScorer`` instances and reduces their single-score outputs into one final true/false score using the supplied @@ -25,7 +26,8 @@ def __init__( aggregator: TrueFalseAggregatorFunc, scorers: List[TrueFalseScorer], ) -> None: - """Initialize the composite scorer. + """ + Initialize the composite scorer. 
Args: aggregator (TrueFalseAggregatorFunc): Aggregation function to combine child scores @@ -53,7 +55,8 @@ async def _score_async( objective: Optional[str] = None, role_filter: Optional[ChatMessageRole] = None, ) -> list[Score]: - """Score a request/response by combining results from all constituent scorers. + """ + Score a request/response by combining results from all constituent scorers. Args: message (Message): The request/response to score. @@ -62,7 +65,6 @@ async def _score_async( Returns: list[Score]: A single-element list with the aggregated true/false score. """ - tasks = [ scorer.score_async(message=message, objective=objective, role_filter=role_filter) for scorer in self._scorers @@ -102,7 +104,8 @@ async def _score_async( return [return_score] async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """Composite scorers do not support piecewise scoring. + """ + Composite scorers do not support piecewise scoring. Args: message_piece (MessagePiece): Unused. diff --git a/pyrit/score/true_false/true_false_inverter_scorer.py b/pyrit/score/true_false/true_false_inverter_scorer.py index 28b717307..6ee765e68 100644 --- a/pyrit/score/true_false/true_false_inverter_scorer.py +++ b/pyrit/score/true_false/true_false_inverter_scorer.py @@ -13,14 +13,14 @@ class TrueFalseInverterScorer(TrueFalseScorer): """A scorer that inverts a true false score.""" def __init__(self, *, scorer: TrueFalseScorer, validator: Optional[ScorerPromptValidator] = None) -> None: - """Initialize the TrueFalseInverterScorer. + """ + Initialize the TrueFalseInverterScorer. Args: scorer (TrueFalseScorer): The underlying true/false scorer whose results will be inverted. validator (Optional[ScorerPromptValidator]): Custom validator. Defaults to None. Note: This parameter is present for signature compatibility but is not used. """ - super().__init__(validator=ScorerPromptValidator()) if not isinstance(scorer, TrueFalseScorer): @@ -34,7 +34,8 @@ async def _score_async( objective: Optional[str] = None, role_filter: Optional[ChatMessageRole] = None, ) -> list[Score]: - """Scores the piece using the underlying true-false scorer and returns the inverted score. + """ + Scores the piece using the underlying true-false scorer and returns the inverted score. Args: message (Message): The message to score. @@ -69,7 +70,8 @@ async def _score_async( return [inv_score] async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: - """True False Inverter scorers do not support piecewise scoring. + """ + True False Inverter scorers do not support piecewise scoring. Args: message_piece (MessagePiece): Unused. diff --git a/pyrit/score/true_false/true_false_score_aggregator.py b/pyrit/score/true_false/true_false_score_aggregator.py index 7bb2dfa5d..88bce9e54 100644 --- a/pyrit/score/true_false/true_false_score_aggregator.py +++ b/pyrit/score/true_false/true_false_score_aggregator.py @@ -17,7 +17,8 @@ def _build_rationale(scores: List[Score], *, result: bool, true_msg: str, false_msg: str) -> tuple[str, str]: - """Build description and rationale for aggregated true/false scores. + """ + Build description and rationale for aggregated true/false scores. Args: scores: List of Score objects to aggregate. @@ -45,7 +46,8 @@ def _create_aggregator( true_msg: str, false_msg: str, ) -> TrueFalseAggregatorFunc: - """Create a True/False aggregator using a result function over boolean values. 
+ """ + Create a True/False aggregator using a result function over boolean values. Args: name (str): Name of the aggregator variant. @@ -100,7 +102,8 @@ def _create_binary_aggregator( true_msg: str, false_msg: str, ) -> TrueFalseAggregatorFunc: - """Turn a binary Boolean operator (e.g. operator.and_) into an aggregation function. + """ + Turn a binary Boolean operator (e.g. operator.and_) into an aggregation function. Args: name (str): Name of the aggregator variant. @@ -121,7 +124,8 @@ def _create_binary_aggregator( # True/False aggregators (return list with single score) class TrueFalseScoreAggregator: - """Namespace for true/false score aggregators that return a single aggregated score. + """ + Namespace for true/false score aggregators that return a single aggregated score. All aggregators return a list containing one ScoreAggregatorResult that combines all input scores together, preserving all categories. diff --git a/pyrit/score/true_false/video_true_false_scorer.py b/pyrit/score/true_false/video_true_false_scorer.py index eea7b30df..8da85b193 100644 --- a/pyrit/score/true_false/video_true_false_scorer.py +++ b/pyrit/score/true_false/video_true_false_scorer.py @@ -59,7 +59,6 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op Returns: List containing a single aggregated score for the video. """ - # Get scores for all frames frame_scores = await self._score_frames_async(message_piece=message_piece, objective=objective) diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index ab7910a86..58d2cfc23 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -204,7 +204,6 @@ def initialize_pyrit( to execute directly. These provide type-safe, validated configuration with clear documentation. **memory_instance_kwargs (Optional[Any]): Additional keyword arguments to pass to the memory instance. """ - # Handle DuckDB deprecation before validation if memory_db_type == "DuckDB": logger.warning( diff --git a/pyrit/setup/initializers/scenarios/objective_list.py b/pyrit/setup/initializers/scenarios/objective_list.py index 95acd2b81..66b9956d7 100644 --- a/pyrit/setup/initializers/scenarios/objective_list.py +++ b/pyrit/setup/initializers/scenarios/objective_list.py @@ -25,7 +25,7 @@ def name(self) -> str: @property def execution_order(self) -> int: - "should be executed after most initializers" + """Should be executed after most initializers.""" return 10 @property diff --git a/pyrit/setup/initializers/scenarios/openai_objective_target.py b/pyrit/setup/initializers/scenarios/openai_objective_target.py index b4107fcb5..3473749aa 100644 --- a/pyrit/setup/initializers/scenarios/openai_objective_target.py +++ b/pyrit/setup/initializers/scenarios/openai_objective_target.py @@ -27,7 +27,7 @@ def name(self) -> str: @property def execution_order(self) -> int: - "should be executed after most initializers" + """Should be executed after most initializers.""" return 10 @property diff --git a/pyrit/show_versions.py b/pyrit/show_versions.py index c46810b34..09f4ce061 100644 --- a/pyrit/show_versions.py +++ b/pyrit/show_versions.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Utility methods to print system info for debugging. +""" +Utility methods to print system info for debugging. Adapted from :py:func:`pandas.show_versions` and :py:func:`sklearn.show_versions`. """ # noqa: RST304 @@ -11,12 +12,11 @@ def _get_sys_info(): - """System information. + """ + System information. 
-    Returns
-    -------
-    sys_info : dict
-        system and Python version information
+    Returns:
+        dict: system and Python version information
     """
     python = sys.version.replace("\n", " ")
@@ -30,15 +30,14 @@ def _get_sys_info():
 
 def _get_deps_info():
-    """Overview of the installed version of main dependencies.
+    """
+    Overview of the installed version of main dependencies.
 
     This function does not import the modules to collect
     the version numbers but instead relies on standard
     Python package metadata.
 
-    Returns
-    -------
-    deps_info: dict
-        version information on relevant Python libraries
+    Returns:
+        dict: version information on relevant Python libraries
     """
     deps = sorted(
         [
diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py
index 6eca61149..32af9e26b 100644
--- a/pyrit/ui/rpc.py
+++ b/pyrit/ui/rpc.py
@@ -126,7 +126,6 @@ def start(self):
         """
         Attempt to start the RPC server. If the server is already running, this method will throw an exception.
         """
-
         # Check if the server is already running by checking if the port is already in use.
         # If the port is already in use, throw an exception.
         if self._is_instance_running():
@@ -177,7 +176,6 @@ def stop_request(self):
         """
-        Request the RPC server to stop. This method is does not block while waiting for the server to stop.
+        Request the RPC server to stop. This method does not block while waiting for the server to stop.
         """
-
         logger.info("RPC server stopping")
         if self._server is not None:
             self._server.close()
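Aside: the `_get_deps_info` docstring above notes that versions come from package metadata rather than imports. A minimal sketch of that approach using only the standard library; the package list here is illustrative.

```python
# Minimal sketch of metadata-based version lookup: read installed versions
# without importing the packages themselves.
from importlib.metadata import PackageNotFoundError, version


def get_deps_info(packages: list[str]) -> dict[str, str | None]:
    info: dict[str, str | None] = {}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            # Not installed; record None instead of raising.
            info[name] = None
    return info


print(get_deps_info(["pip", "numpy", "not-a-real-package"]))
```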