diff --git a/pyproject.toml b/pyproject.toml index c63b6f1699..9c8adfb1a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,9 +173,8 @@ asyncio_mode = "auto" [tool.mypy] plugins = [] ignore_missing_imports = true -strict = false +strict = true follow_imports = "silent" -strict_optional = false disable_error_code = ["empty-body"] exclude = ["doc/code/", "pyrit/auxiliary_attacks/"] diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index 3c830050b6..a403d1aa37 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -62,9 +62,8 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats raise TypeError(f"Expected AttackResult, got {type(attack).__name__}: {attack!r}") outcome = attack.outcome - attack_type = ( - attack.get_attack_strategy_identifier().class_name if attack.get_attack_strategy_identifier() else "unknown" - ) + _strategy_id = attack.get_attack_strategy_identifier() + attack_type = _strategy_id.class_name if _strategy_id is not None else "unknown" if outcome == AttackOutcome.SUCCESS: overall_counts["successes"] += 1 diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 4149749e45..e56fd74315 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -296,7 +296,7 @@ def get_access_token_from_interactive_login(scope: str) -> str: """ try: token_provider = get_bearer_token_provider(InteractiveBrowserCredential(), scope) - return token_provider() + return str(token_provider()) except Exception as e: logger.error(f"Failed to obtain token for '{scope}': {e}") raise @@ -320,7 +320,7 @@ def get_azure_token_provider(scope: str) -> Callable[[], str]: >>> token = token_provider() # Get current token """ try: - return get_bearer_token_provider(DefaultAzureCredential(), scope) + return get_bearer_token_provider(DefaultAzureCredential(), scope) # type: ignore[no-any-return] except Exception as e: logger.error(f"Failed to obtain token provider for '{scope}': {e}") raise diff --git a/pyrit/auth/copilot_authenticator.py b/pyrit/auth/copilot_authenticator.py index ea85979fb6..bf20f47949 100644 --- a/pyrit/auth/copilot_authenticator.py +++ b/pyrit/auth/copilot_authenticator.py @@ -415,11 +415,15 @@ async def response_handler(response: Any) -> None: logger.info("Waiting for email input...") await page.wait_for_selector("#i0116", timeout=self._elements_timeout) + if self._username is None: + raise ValueError("Username is not set") await page.fill("#i0116", self._username) await page.click("#idSIButton9") logger.info("Waiting for password input...") await page.wait_for_selector("#i0118", timeout=self._elements_timeout) + if self._password is None: + raise ValueError("Password is not set") await page.fill("#i0118", self._password) await page.click("#idSIButton9") @@ -450,7 +454,7 @@ async def response_handler(response: Any) -> None: else: logger.error(f"Failed to retrieve bearer token within {self._token_capture_timeout} seconds.") - return bearer_token # type: ignore[no-any-return] + return bearer_token except Exception as e: logger.error("Failed to retrieve access token using Playwright.") diff --git a/pyrit/backend/routes/media.py b/pyrit/backend/routes/media.py index ee0835c715..e0a1daf682 100644 --- a/pyrit/backend/routes/media.py +++ b/pyrit/backend/routes/media.py @@ -87,7 +87,7 @@ async def serve_media_async( # Determine allowed directory from memory results_path try: memory = CentralMemory.get_memory_instance() - allowed_root = Path(memory.results_path).resolve() + allowed_root = Path(memory.results_path or "").resolve() except Exception as exc: raise HTTPException(status_code=500, detail="Memory not initialized; cannot determine results path.") from exc diff --git a/pyrit/backend/routes/version.py b/pyrit/backend/routes/version.py index f550084eb8..b59d176158 100644 --- a/pyrit/backend/routes/version.py +++ b/pyrit/backend/routes/version.py @@ -67,7 +67,7 @@ async def get_version_async(request: Request) -> VersionResponse: memory = CentralMemory.get_memory_instance() db_type = type(memory).__name__ db_name = None - if memory.engine.url.database: + if memory.engine is not None and memory.engine.url.database: db_name = memory.engine.url.database.split("?")[0] database_info = f"{db_type} ({db_name})" if db_name else f"{db_type} (None)" except Exception as e: diff --git a/pyrit/cli/_banner.py b/pyrit/cli/_banner.py index 859cb107ac..54d149abbf 100644 --- a/pyrit/cli/_banner.py +++ b/pyrit/cli/_banner.py @@ -566,11 +566,11 @@ def _render_line_with_segments( result: list[str] = [] current_role: Optional[ColorRole] = None for pos, ch in enumerate(line): - role = char_roles[pos] - if role != current_role: - color = _get_color(role, theme) if role else reset + char_role = char_roles[pos] + if char_role != current_role: + color = _get_color(char_role, theme) if char_role else reset result.append(color) - current_role = role + current_role = char_role result.append(ch) result.append(reset) return "".join(result) diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 5ff336c181..3b19b8a406 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -41,7 +41,7 @@ class termcolor: # type: ignore[no-redef] # noqa: N801 """Dummy termcolor fallback for colored printing if termcolor is not installed.""" @staticmethod - def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: ignore[type-arg] + def cprint(text: str, color: Optional[str] = None, attrs: Optional[list[Any]] = None) -> None: """Print text without color.""" print(text) @@ -187,7 +187,8 @@ def scenario_registry(self) -> ScenarioRegistry: raise RuntimeError( "FrontendCore not initialized. Call 'await context.initialize_async()' before accessing registries." ) - assert self._scenario_registry is not None + if self._scenario_registry is None: + raise ValueError("self._scenario_registry is not initialized") return self._scenario_registry @property @@ -202,7 +203,8 @@ def initializer_registry(self) -> InitializerRegistry: raise RuntimeError( "FrontendCore not initialized. Call 'await context.initialize_async()' before accessing registries." ) - assert self._initializer_registry is not None + if self._initializer_registry is None: + raise ValueError("self._initializer_registry is not initialized") return self._initializer_registry diff --git a/pyrit/common/data_url_converter.py b/pyrit/common/data_url_converter.py index 4fd6bb3b16..64d3ea97fb 100644 --- a/pyrit/common/data_url_converter.py +++ b/pyrit/common/data_url_converter.py @@ -23,7 +23,7 @@ async def convert_local_image_to_data_url(image_path: str) -> str: str: A string containing the MIME type and the base64-encoded data of the image, formatted as a data URL. """ ext = DataTypeSerializer.get_extension(image_path) - if ext.lower() not in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS: + if not ext or ext.lower() not in AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS: raise ValueError( f"Unsupported image format: {ext}. Supported formats are: {AZURE_OPENAI_GPT4O_SUPPORTED_IMAGE_FORMATS}" ) diff --git a/pyrit/common/display_response.py b/pyrit/common/display_response.py index 7341df8376..893896d413 100644 --- a/pyrit/common/display_response.py +++ b/pyrit/common/display_response.py @@ -1,52 +1,54 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import io -import logging - -from PIL import Image - -from pyrit.common.notebook_utils import is_in_ipython_session -from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece - -logger = logging.getLogger(__name__) - - -async def display_image_response(response_piece: MessagePiece) -> None: - """ - Display response images if running in notebook environment. - - Args: - response_piece (MessagePiece): The response piece to display. - """ - from pyrit.memory import CentralMemory - - memory = CentralMemory.get_memory_instance() - if ( - response_piece.response_error == "none" - and response_piece.converted_value_data_type == "image_path" - and is_in_ipython_session() - ): - image_location = response_piece.converted_value - - try: - image_bytes = await memory.results_storage_io.read_file(image_location) - except Exception as e: - if isinstance(memory.results_storage_io, AzureBlobStorageIO): - try: - # Fallback to reading from disk if the storage IO fails - image_bytes = await DiskStorageIO().read_file(image_location) - except Exception as exc: - logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}") - return - else: - logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}") - return - - image_stream = io.BytesIO(image_bytes) - image = Image.open(image_stream) - - # Jupyter built-in display function only works in notebooks. - display(image) # type: ignore[name-defined] # noqa: F821 - if response_piece.response_error == "blocked": - logger.info("---\nContent blocked, cannot show a response.\n---") +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io +import logging + +from PIL import Image + +from pyrit.common.notebook_utils import is_in_ipython_session +from pyrit.models import AzureBlobStorageIO, DiskStorageIO, MessagePiece + +logger = logging.getLogger(__name__) + + +async def display_image_response(response_piece: MessagePiece) -> None: + """ + Display response images if running in notebook environment. + + Args: + response_piece (MessagePiece): The response piece to display. + """ + from pyrit.memory import CentralMemory + + memory = CentralMemory.get_memory_instance() + if ( + response_piece.response_error == "none" + and response_piece.converted_value_data_type == "image_path" + and is_in_ipython_session() + ): + image_location = response_piece.converted_value + + try: + if memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") + image_bytes = await memory.results_storage_io.read_file(image_location) + except Exception as e: + if isinstance(memory.results_storage_io, AzureBlobStorageIO): + try: + # Fallback to reading from disk if the storage IO fails + image_bytes = await DiskStorageIO().read_file(image_location) + except Exception as exc: + logger.error(f"Failed to read image from {image_location}. Full exception: {str(exc)}") + return + else: + logger.error(f"Failed to read image from {image_location}. Full exception: {str(e)}") + return + + image_stream = io.BytesIO(image_bytes) + image = Image.open(image_stream) + + # Jupyter built-in display function only works in notebooks. + display(image) # type: ignore[name-defined] # noqa: F821 + if response_piece.response_error == "blocked": + logger.info("---\nContent blocked, cannot show a response.\n---") diff --git a/pyrit/common/net_utility.py b/pyrit/common/net_utility.py index 2ecff147a5..cb92e40952 100644 --- a/pyrit/common/net_utility.py +++ b/pyrit/common/net_utility.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Any, Literal, Optional, overload +from typing import Any, Literal, Optional, cast, overload from urllib.parse import parse_qs, urlparse, urlunparse import httpx @@ -10,18 +10,18 @@ @overload def get_httpx_client( - use_async: Literal[True], debug: bool = False, **httpx_client_kwargs: Optional[Any] + use_async: Literal[True], debug: bool = False, **httpx_client_kwargs: Any ) -> httpx.AsyncClient: ... @overload def get_httpx_client( - use_async: Literal[False] = False, debug: bool = False, **httpx_client_kwargs: Optional[Any] + use_async: Literal[False] = False, debug: bool = False, **httpx_client_kwargs: Any ) -> httpx.Client: ... def get_httpx_client( - use_async: bool = False, debug: bool = False, **httpx_client_kwargs: Optional[Any] + use_async: bool = False, debug: bool = False, **httpx_client_kwargs: Any ) -> httpx.Client | httpx.AsyncClient: """ Get the httpx client for making requests. @@ -32,10 +32,10 @@ def get_httpx_client( client_class = httpx.AsyncClient if use_async else httpx.Client proxy = "http://localhost:8080" if debug else None - proxy = httpx_client_kwargs.pop("proxy", proxy) - verify_certs = httpx_client_kwargs.pop("verify", not debug) + proxy = cast(Optional[str], httpx_client_kwargs.pop("proxy", proxy)) + verify_certs = cast(bool, httpx_client_kwargs.pop("verify", not debug)) # fun notes; httpx default is 5 seconds, httpclient is 100, urllib in indefinite - timeout = httpx_client_kwargs.pop("timeout", 60.0) + timeout = cast(float, httpx_client_kwargs.pop("timeout", 60.0)) return client_class(proxy=proxy, verify=verify_certs, timeout=timeout, **httpx_client_kwargs) @@ -92,7 +92,7 @@ async def make_request_and_raise_if_error_async( request_body: Optional[dict[str, object]] = None, files: Optional[dict[str, tuple[str, bytes, str]]] = None, headers: Optional[dict[str, str]] = None, - **httpx_client_kwargs: Optional[Any], + **httpx_client_kwargs: Any, ) -> httpx.Response: """ Make a request and raise an exception if it fails. diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py index da245b6875..5d9a7b96ef 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py @@ -232,9 +232,13 @@ async def _fetch_and_save_image_async(self, image_url: str, behavior_id: str) -> serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") # Return existing path if image already exists for this BehaviorID - serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}") + results_path = serializer._memory.results_path + results_storage_io = serializer._memory.results_storage_io + if not results_path or results_storage_io is None: + raise RuntimeError("[HarmBench-Multimodal] Serializer memory is not properly configured: results_path and results_storage_io must be set.") + serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") try: - if await serializer._memory.results_storage_io.path_exists(serializer.value): + if await results_storage_io.path_exists(serializer.value): return serializer.value except Exception as e: logger.warning(f"[HarmBench-Multimodal] Failed to check if image for {behavior_id} exists in cache: {e}") diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py index 472d43022f..22e1860df3 100644 --- a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py @@ -171,6 +171,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: group_id = uuid.uuid4() try: + if image_url is None or text is None: + continue local_image_path = await self._fetch_and_save_image_async(image_url, str(group_id)) # Create text prompt (sequence=0, sent first) @@ -179,13 +181,13 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: data_type="text", name="ML-VLSU Text", dataset_name=self.dataset_name, - harm_categories=[combined_category], + harm_categories=[combined_category or ""], description="Text component of ML-VLSU multimodal prompt.", source=self.source, prompt_group_id=group_id, sequence=0, metadata={ - "category": combined_category, + "category": combined_category or "", "text_grade": text_grade, "image_grade": image_grade, "combined_grade": combined_grade, @@ -198,13 +200,13 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: data_type="image_path", name="ML-VLSU Image", dataset_name=self.dataset_name, - harm_categories=[combined_category], + harm_categories=[combined_category or ""], description="Image component of ML-VLSU multimodal prompt.", source=self.source, prompt_group_id=group_id, sequence=1, metadata={ - "category": combined_category, + "category": combined_category or "", "text_grade": text_grade, "image_grade": image_grade, "combined_grade": combined_grade, @@ -245,9 +247,13 @@ async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> st serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png") # Return existing path if image already exists - serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}") + results_path = serializer._memory.results_path + results_storage_io = serializer._memory.results_storage_io + if not results_path or results_storage_io is None: + raise RuntimeError("[ML-VLSU] Serializer memory is not properly configured.") + serializer.value = str(results_path + serializer.data_sub_directory + f"/{filename}") try: - if await serializer._memory.results_storage_io.path_exists(serializer.value): + if await results_storage_io.path_exists(serializer.value): return serializer.value except Exception as e: logger.warning(f"[ML-VLSU] Failed to check if image for {group_id} exists in cache: {e}") diff --git a/pyrit/embedding/openai_text_embedding.py b/pyrit/embedding/openai_text_embedding.py index f6b51a10b8..66b00280bc 100644 --- a/pyrit/embedding/openai_text_embedding.py +++ b/pyrit/embedding/openai_text_embedding.py @@ -63,7 +63,7 @@ def __init__( # Create async client - type: ignore needed because get_required_value returns str # but api_key parameter accepts str | Callable[[], str | Awaitable[str]] self._async_client = AsyncOpenAI( - api_key=api_key, # type: ignore[arg-type] + api_key=api_key, base_url=endpoint, ) diff --git a/pyrit/executor/attack/core/attack_parameters.py b/pyrit/executor/attack/core/attack_parameters.py index 95635cde3b..53bd34f6f5 100644 --- a/pyrit/executor/attack/core/attack_parameters.py +++ b/pyrit/executor/attack/core/attack_parameters.py @@ -123,7 +123,8 @@ async def from_seed_group_async( seed_group.validate() # SeedAttackGroup validates in __init__ that objective is set - assert seed_group.objective is not None + if seed_group.objective is None: + raise ValueError("seed_group.objective is not initialized") # Build params dict, only including fields this class accepts params: dict[str, Any] = {} diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index ece67750ef..70b57332c7 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -166,7 +166,7 @@ class TAPAttackResult(AttackResult): @property def tree_visualization(self) -> Optional[Tree]: """Get the tree visualization from metadata.""" - return cast("Optional[Tree]", self.metadata.get("tree_visualization", None)) + return self.metadata.get("tree_visualization", None) @tree_visualization.setter def tree_visualization(self, value: Tree) -> None: @@ -545,6 +545,8 @@ async def _send_prompt_to_target_async(self, prompt: str) -> Message: ) # Store the last response text for reference + if response is None: + raise ValueError("Response was None") response_piece = response.get_piece() self.last_response = response_piece.converted_value logger.debug(f"Node {self.node_id}: Received response from target") @@ -601,6 +603,8 @@ async def _send_initial_prompt_to_target_async(self) -> Message: ) # Store the last response text for reference + if response is None: + raise ValueError("Response was None") response_piece = response.get_piece() self.last_response = response_piece.converted_value logger.debug(f"Node {self.node_id}: Received response from target") @@ -1111,6 +1115,8 @@ async def _send_to_adversarial_chat_async(self, prompt_text: str) -> str: attack_identifier=self._attack_id, ) + if response is None: + raise ValueError("Response was None") return response.get_value() def _parse_red_teaming_response(self, red_teaming_response: str) -> str: @@ -1359,6 +1365,8 @@ def __init__( "TAP attack requires a FloatScaleThresholdScorer for objective_scorer. " "Please wrap your scorer in FloatScaleThresholdScorer with an appropriate threshold." ) + if objective_scorer is None: + raise ValueError("objective_scorer is required") tap_scoring_config = TAPAttackScoringConfig( objective_scorer=objective_scorer, refusal_scorer=attack_scoring_config.refusal_scorer, diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index e50446bb38..5946ce985c 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -487,9 +487,8 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> list[str]: markdown_lines.append("|-------|-------|") markdown_lines.append(f"| **Objective** | {result.objective} |") - attack_type = ( - result.get_attack_strategy_identifier().class_name if result.get_attack_strategy_identifier() else "Unknown" - ) + _strategy_id = result.get_attack_strategy_identifier() + attack_type = _strategy_id.class_name if _strategy_id is not None else "Unknown" markdown_lines.append(f"| **Attack Type** | `{attack_type}` |") markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |") diff --git a/pyrit/executor/attack/single_turn/context_compliance.py b/pyrit/executor/attack/single_turn/context_compliance.py index d03ab2a41f..55e4a82e02 100644 --- a/pyrit/executor/attack/single_turn/context_compliance.py +++ b/pyrit/executor/attack/single_turn/context_compliance.py @@ -238,6 +238,8 @@ async def _get_objective_as_benign_question_async( labels=context.memory_labels, ) + if response is None: + raise ValueError("Response was None") return response.get_value() async def _get_benign_question_answer_async( @@ -265,6 +267,8 @@ async def _get_benign_question_answer_async( labels=context.memory_labels, ) + if response is None: + raise ValueError("Response was None") return response.get_value() async def _get_objective_as_question_async(self, *, objective: str, context: SingleTurnAttackContext[Any]) -> str: @@ -290,6 +294,8 @@ async def _get_objective_as_question_async(self, *, objective: str, context: Sin labels=context.memory_labels, ) + if response is None: + raise ValueError("Response was None") return response.get_value() def _construct_assistant_response(self, *, benign_answer: str, objective_question: str) -> str: diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index 208c4040d7..7400719054 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -358,7 +358,8 @@ async def _extract_knowledge_graph_async(self, *, context: AnecdoctorContext) -> RuntimeError: If knowledge graph extraction fails. """ # Processing model is guaranteed to exist when this method is called - assert self._processing_model is not None + if self._processing_model is None: + raise ValueError("self._processing_model is not initialized") self._logger.debug("Extracting knowledge graph from evaluation data") diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 7021c0d6ad..ea33c36e80 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -1020,8 +1020,11 @@ def _create_normalizer_requests(self, prompts: list[str]) -> list[NormalizerRequ for prompt in prompts: seed_group = SeedGroup(seeds=[SeedPrompt(value=prompt, data_type="text")]) + _msg = seed_group.next_message + if _msg is None: + raise ValueError("No message in seed group") request = NormalizerRequest( - message=seed_group.next_message, + message=_msg, request_converter_configurations=self._request_converters, response_converter_configurations=self._response_converters, ) diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index 2dc021b497..a053e35a30 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -339,6 +339,8 @@ async def _setup_attack_async(self, *, context: XPIAContext) -> str: conversation_id=context.attack_setup_target_conversation_id, ) + if setup_response is None: + raise ValueError("Setup response was None") setup_response_text = setup_response.get_value() self._logger.info(f'Received the following response from the prompt target: "{setup_response_text}"') @@ -357,7 +359,11 @@ async def _execute_processing_async(self, *, context: XPIAContext) -> str: Returns: str: The response from the processing target. """ + if context.processing_callback is None: + raise ValueError("processing_callback is not set") processing_response = await context.processing_callback() + if self._memory is None: + raise ValueError("Memory not initialized") self._memory.add_message_to_memory( request=Message( message_pieces=[ @@ -560,7 +566,8 @@ async def _setup_async(self, *, context: XPIAContext) -> None: # Create the processing callback using the test context async def process_async() -> str: # processing_prompt is validated to be non-None in _validate_context - assert context.processing_prompt is not None + if context.processing_prompt is None: + raise ValueError("context.processing_prompt is not initialized") response = await self._prompt_normalizer.send_prompt_async( message=context.processing_prompt, target=self._processing_target, @@ -571,6 +578,8 @@ async def process_async() -> str: conversation_id=context.processing_conversation_id, ) + if response is None: + raise ValueError("Response was None") return response.get_value() # Set the processing callback on the context diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index c9f349c0d9..a04aa74129 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -143,7 +143,7 @@ def _refresh_token_if_needed(self) -> None: """ Refresh the access token if it is close to expiry (within 5 minutes). """ - if datetime.now(timezone.utc) >= datetime.fromtimestamp(self._auth_token_expiry, tz=timezone.utc) - timedelta( + if self._auth_token_expiry is not None and datetime.now(timezone.utc) >= datetime.fromtimestamp(float(self._auth_token_expiry), tz=timezone.utc) - timedelta( minutes=5 ): logger.info("Refreshing Microsoft Entra ID access token...") @@ -201,6 +201,8 @@ def provide_token(_dialect: Any, _conn_rec: Any, cargs: list[Any], cparams: dict cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "") # encode the token + if self._auth_token is None: + raise RuntimeError("Azure auth token is not initialized") azure_token = self._auth_token.token azure_token_bytes = azure_token.encode("utf-16-le") packed_azure_token = struct.pack(f" None: """ try: # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables + if self.engine is None: + raise RuntimeError("Engine is not initialized") Base.metadata.create_all(self.engine, checkfirst=True) except Exception as e: logger.exception(f"Error during table creation: {e}") @@ -791,6 +795,10 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict def reset_database(self) -> None: """Drop and recreate existing tables.""" # Drop all existing tables + if self.engine is None: + + raise RuntimeError("Engine is not initialized") + Base.metadata.drop_all(self.engine) # Recreate the tables Base.metadata.create_all(self.engine, checkfirst=True) diff --git a/pyrit/memory/central_memory.py b/pyrit/memory/central_memory.py index a933e73107..0ef8afe372 100644 --- a/pyrit/memory/central_memory.py +++ b/pyrit/memory/central_memory.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. +from typing import Optional # Licensed under the MIT license. import logging @@ -14,7 +15,7 @@ class CentralMemory: The provided memory instance will be reused for future calls. """ - _memory_instance: MemoryInterface = None + _memory_instance: Optional[MemoryInterface] = None @classmethod def set_memory_instance(cls, passed_memory: MemoryInterface) -> None: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 90322ebec4..fc57d6a366 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -69,10 +69,10 @@ class MemoryInterface(abc.ABC): such as files, databases, or cloud storage services. """ - memory_embedding: MemoryEmbedding = None - results_storage_io: StorageIO = None - results_path: str = None - engine: Engine = None + memory_embedding: Optional[MemoryEmbedding] = None + results_storage_io: Optional[StorageIO] = None + results_path: Optional[str] = None + engine: Optional[Engine] = None def __init__(self, embedding_model: Optional[Any] = None) -> None: """ @@ -1007,7 +1007,7 @@ async def _serialize_seed_value(self, prompt: Seed) -> str: audio_bytes = await serializer.read_data() await serializer.save_data(data=audio_bytes) serialized_prompt_value = str(serializer.value) - return serialized_prompt_value + return serialized_prompt_value or "" async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Optional[str] = None) -> None: """ @@ -1044,7 +1044,7 @@ async def add_seeds_to_memory_async(self, *, seeds: Sequence[Seed], added_by: Op await prompt.set_sha256_value_async() - if not self.get_seeds(value_sha256=[prompt.value_sha256], dataset_name=prompt.dataset_name): + if prompt.value_sha256 and not self.get_seeds(value_sha256=[prompt.value_sha256], dataset_name=prompt.dataset_name): entries.append(SeedEntry(entry=prompt)) self._insert_entries(entries=entries) @@ -1724,7 +1724,8 @@ def get_scenario_results( def print_schema(self) -> None: """Print the schema of all tables in the database.""" metadata = MetaData() - metadata.reflect(bind=self.engine) + if self.engine: + metadata.reflect(bind=self.engine) for table_name in metadata.tables: table = metadata.tables[table_name] diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index e9c83b9300..5077415942 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -395,7 +395,7 @@ def __init__(self, *, entry: Score): self.score_type = entry.score_type self.score_category = entry.score_category self.score_rationale = entry.score_rationale - self.score_metadata = entry.score_metadata + self.score_metadata = entry.score_metadata # type: ignore[assignment] # Normalize to ComponentIdentifier (handles dict with deprecation warning) then convert to dict for JSON storage normalized_scorer = ComponentIdentifier.normalize(entry.scorer_class_identifier) self.scorer_class_identifier = normalized_scorer.to_dict(max_value_length=MAX_IDENTIFIER_VALUE_LENGTH) @@ -429,7 +429,7 @@ def get_score(self) -> Score: score_category=self.score_category, score_rationale=self.score_rationale, score_metadata=self.score_metadata, - scorer_class_identifier=scorer_identifier, + scorer_class_identifier=scorer_identifier, # type: ignore[arg-type] message_piece_id=self.prompt_request_response_id, timestamp=_ensure_utc(self.timestamp), objective=self.objective, @@ -584,7 +584,7 @@ def __init__(self, *, entry: Seed): self.source = entry.source self.date_added = entry.date_added self.added_by = entry.added_by - self.prompt_metadata = entry.metadata + self.prompt_metadata = entry.metadata # type: ignore[assignment] self.prompt_group_id = entry.prompt_group_id self.seed_type = seed_type # Deprecated: kept for backward compatibility with existing databases @@ -594,11 +594,11 @@ def __init__(self, *, entry: Seed): if isinstance(entry, SeedPrompt): self.parameters = list(entry.parameters) if entry.parameters else None self.sequence = entry.sequence - self.role = entry.role + self.role = entry.role # type: ignore[assignment] else: self.parameters = None self.sequence = None - self.role = None + self.role = None # type: ignore[assignment] def get_seed(self) -> Seed: """ @@ -683,7 +683,7 @@ def get_seed(self) -> Seed: metadata=self.prompt_metadata, parameters=self.parameters, prompt_group_id=self.prompt_group_id, - sequence=self.sequence, + sequence=self.sequence or 0, role=self.role, ) @@ -1033,7 +1033,7 @@ def get_scenario_result(self) -> ScenarioResult: scenario_identifier=scenario_identifier, objective_target_identifier=target_identifier, attack_results=attack_results, - objective_scorer_identifier=scorer_identifier, + objective_scorer_identifier=scorer_identifier, # type: ignore[arg-type] scenario_run_state=self.scenario_run_state, labels=self.labels, number_tries=self.number_tries, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index d59a6571d2..f70836221c 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -111,6 +111,8 @@ def _create_tables_if_not_exist(self) -> None: """ try: # Using the 'checkfirst=True' parameter to avoid attempting to recreate existing tables + if self.engine is None: + raise RuntimeError("Engine is not initialized") Base.metadata.create_all(self.engine, checkfirst=True) except Exception as e: logger.exception(f"Error during table creation: {e}") @@ -337,7 +339,15 @@ def reset_database(self) -> None: """ Drop and recreates all tables in the database. """ + if self.engine is None: + + raise RuntimeError("Engine is not initialized") + Base.metadata.drop_all(self.engine) + if self.engine is None: + + raise RuntimeError("Engine is not initialized") + Base.metadata.create_all(self.engine) def dispose_engine(self) -> None: diff --git a/pyrit/message_normalizer/chat_message_normalizer.py b/pyrit/message_normalizer/chat_message_normalizer.py index 0ebfa37946..2fa3bfc0a2 100644 --- a/pyrit/message_normalizer/chat_message_normalizer.py +++ b/pyrit/message_normalizer/chat_message_normalizer.py @@ -164,7 +164,7 @@ async def _convert_audio_to_input_audio(self, audio_path: str) -> dict[str, Any] ValueError: If the audio format is not supported. FileNotFoundError: If the audio file does not exist. """ - ext = DataTypeSerializer.get_extension(audio_path).lower() + ext = (DataTypeSerializer.get_extension(audio_path) or "").lower() if ext not in SUPPORTED_AUDIO_FORMATS: raise ValueError( f"Unsupported audio format: {ext}. Supported formats are: {list(SUPPORTED_AUDIO_FORMATS.keys())}" diff --git a/pyrit/models/data_type_serializer.py b/pyrit/models/data_type_serializer.py index c2004160fb..eafbebebf9 100644 --- a/pyrit/models/data_type_serializer.py +++ b/pyrit/models/data_type_serializer.py @@ -96,7 +96,7 @@ class DataTypeSerializer(abc.ABC): data_sub_directory: str file_extension: str - _file_path: Union[Path, str] = None + _file_path: Optional[Union[Path, str]] = None @property def _memory(self) -> MemoryInterface: @@ -118,7 +118,7 @@ def _get_storage_io(self) -> StorageIO: if self._is_azure_storage_url(self.value): # Scenarios where a user utilizes an in-memory DuckDB but also needs to interact # with an Azure Storage Account, ex., XPIAWorkflow. - return self._memory.results_storage_io + return self._memory.results_storage_io or DiskStorageIO() return DiskStorageIO() @abc.abstractmethod @@ -141,10 +141,12 @@ async def save_data(self, data: bytes, output_filename: Optional[str] = None) -> """ file_path = await self.get_data_filename(file_name=output_filename) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") await self._memory.results_storage_io.write_file(file_path, data) self.value = str(file_path) - async def save_b64_image(self, data: str | bytes, output_filename: str = None) -> None: + async def save_b64_image(self, data: str | bytes, output_filename: Optional[str] = None) -> None: """ Save a base64-encoded image to storage. @@ -155,6 +157,8 @@ async def save_b64_image(self, data: str | bytes, output_filename: str = None) - """ file_path = await self.get_data_filename(file_name=output_filename) image_bytes = base64.b64decode(data) + if self._memory.results_storage_io is None: + raise RuntimeError("Storage IO not initialized") await self._memory.results_storage_io.write_file(file_path, image_bytes) self.value = str(file_path) @@ -190,6 +194,8 @@ async def save_formatted_audio( async with aiofiles.open(local_temp_path, "rb") as f: audio_data = await f.read() + if self._memory.results_storage_io is None: + raise ValueError("self._memory.results_storage_io is not initialized") await self._memory.results_storage_io.write_file(file_path, audio_data) os.remove(local_temp_path) @@ -253,7 +259,7 @@ async def get_sha256(self) -> str: ValueError: If in-memory data cannot be converted to bytes. """ - input_bytes: bytes = None + input_bytes: Optional[bytes] = None if self.data_on_disk(): storage_io = self._get_storage_io() @@ -297,7 +303,11 @@ async def get_data_filename(self, file_name: Optional[str] = None) -> Union[Path raise RuntimeError("Data sub directory not set") ticks = int(time.time() * 1_000_000) - results_path = self._memory.results_path + if self._memory.results_path: + results_path = str(self._memory.results_path) + else: + from pyrit.common.path import DB_DATA_PATH + results_path = str(DB_DATA_PATH) file_name = file_name if file_name else str(ticks) if self._is_azure_storage_url(results_path): @@ -305,6 +315,8 @@ async def get_data_filename(self, file_name: Optional[str] = None) -> Union[Path self._file_path = full_data_directory_path + f"/{file_name}.{self.file_extension}" else: full_data_directory_path = results_path + self.data_sub_directory + if self._memory.results_storage_io is None: + raise ValueError("self._memory.results_storage_io is not initialized") await self._memory.results_storage_io.create_directory_if_not_exists(Path(full_data_directory_path)) self._file_path = Path(full_data_directory_path, f"{file_name}.{self.file_extension}") diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 083728aa0d..3164b3438f 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -297,7 +297,7 @@ def set_piece_not_in_database(self) -> None: This is needed when we're scoring prompts or other things that have not been sent by PyRIT """ - self.id = None + self.id = None # type: ignore[assignment] def to_dict(self) -> dict[str, object]: """ diff --git a/pyrit/models/seeds/seed_attack_group.py b/pyrit/models/seeds/seed_attack_group.py index b994f5108e..99438ee378 100644 --- a/pyrit/models/seeds/seed_attack_group.py +++ b/pyrit/models/seeds/seed_attack_group.py @@ -87,5 +87,6 @@ def objective(self) -> SeedObjective: """ obj = self._get_objective() - assert obj is not None, "SeedAttackGroup should always have an objective" + if obj is None: + raise ValueError("SeedAttackGroup should always have an objective") return obj diff --git a/pyrit/models/seeds/seed_dataset.py b/pyrit/models/seeds/seed_dataset.py index 8bfe6edefc..f7026fef31 100644 --- a/pyrit/models/seeds/seed_dataset.py +++ b/pyrit/models/seeds/seed_dataset.py @@ -171,14 +171,14 @@ def __init__( } if effective_type == "simulated_conversation": - self.seeds.append( - SeedSimulatedConversation( - **base_params, - num_turns=p.get("num_turns", 3), - adversarial_chat_system_prompt_path=p.get("adversarial_chat_system_prompt_path"), - simulated_target_system_prompt_path=p.get("simulated_target_system_prompt_path"), - ) - ) + _adv_path = p.get("adversarial_chat_system_prompt_path") + _sim_path = p.get("simulated_target_system_prompt_path") + _sc_kwargs: dict[str, Any] = {**base_params, "num_turns": p.get("num_turns", 3)} + if _adv_path is not None: + _sc_kwargs["adversarial_chat_system_prompt_path"] = str(_adv_path) + if _sim_path is not None: + _sc_kwargs["simulated_target_system_prompt_path"] = str(_sim_path) + self.seeds.append(SeedSimulatedConversation(**_sc_kwargs)) elif effective_type == "objective": # SeedObjective inherits data_type="text" from base Seed property base_params["value"] = p["value"] diff --git a/pyrit/models/seeds/seed_prompt.py b/pyrit/models/seeds/seed_prompt.py index b507cf3173..ab75132a0f 100644 --- a/pyrit/models/seeds/seed_prompt.py +++ b/pyrit/models/seeds/seed_prompt.py @@ -35,7 +35,7 @@ class SeedPrompt(Seed): # The type of data this prompt represents (e.g., text, image_path, audio_path, video_path) # This field shadows the base class property to allow per-prompt data types - data_type: Optional[PromptDataType] = None + data_type: Optional[PromptDataType] = None # type: ignore[assignment] # Role of the prompt in a conversation (e.g., "user", "assistant") role: Optional[ChatMessageRole] = None diff --git a/pyrit/models/storage_io.py b/pyrit/models/storage_io.py index 7d00dc1570..8d288d3b1c 100644 --- a/pyrit/models/storage_io.py +++ b/pyrit/models/storage_io.py @@ -292,7 +292,7 @@ async def read_file(self, path: Union[Path, str]) -> bytes: # Download the blob blob_stream = await blob_client.download_blob() - return await blob_stream.readall() + return bytes(await blob_stream.readall()) except Exception as exc: logger.exception(f"Failed to read file at {blob_name}: {exc}") @@ -363,7 +363,7 @@ async def is_file(self, path: Union[Path, str]) -> bool: _, blob_name = self.parse_blob_url(str(path)) blob_client = self._client_async.get_blob_client(blob=blob_name) blob_properties = await blob_client.get_blob_properties() - return blob_properties.size > 0 + return bool(blob_properties.size > 0) except ResourceNotFoundError: return False finally: diff --git a/pyrit/prompt_converter/add_image_text_converter.py b/pyrit/prompt_converter/add_image_text_converter.py index 8cbf4d8671..b5fb4db89f 100644 --- a/pyrit/prompt_converter/add_image_text_converter.py +++ b/pyrit/prompt_converter/add_image_text_converter.py @@ -168,7 +168,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text updated_img = self._add_text_to_image(text=prompt) image_bytes = BytesIO() - mime_type = img_serializer.get_mime_type(self._img_to_add) + mime_type = img_serializer.get_mime_type(self._img_to_add) or "image/png" image_type = mime_type.split("/")[-1] updated_img.save(image_bytes, format=image_type) image_str = base64.b64encode(image_bytes.getvalue()) diff --git a/pyrit/prompt_converter/add_text_image_converter.py b/pyrit/prompt_converter/add_text_image_converter.py index 91fd265e57..ea3236b403 100644 --- a/pyrit/prompt_converter/add_text_image_converter.py +++ b/pyrit/prompt_converter/add_text_image_converter.py @@ -165,7 +165,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "imag updated_img = self._add_text_to_image(image=original_img) image_bytes = BytesIO() - mime_type = img_serializer.get_mime_type(prompt) + mime_type = img_serializer.get_mime_type(prompt) or "image/png" image_type = mime_type.split("/")[-1] updated_img.save(image_bytes, format=image_type) image_str = base64.b64encode(image_bytes.getvalue()).decode("utf-8") diff --git a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py index 7c5fdad176..37ca3f4ec1 100644 --- a/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py +++ b/pyrit/prompt_converter/azure_speech_text_to_audio_converter.py @@ -181,4 +181,4 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text except Exception as e: logger.error("Failed to convert prompt to audio: %s", str(e)) raise - return ConverterResult(output_text=audio_serializer_file, output_type="audio_path") + return ConverterResult(output_text=audio_serializer_file or "", output_type="audio_path") diff --git a/pyrit/prompt_converter/codechameleon_converter.py b/pyrit/prompt_converter/codechameleon_converter.py index 99454946df..43bb1e4ab8 100644 --- a/pyrit/prompt_converter/codechameleon_converter.py +++ b/pyrit/prompt_converter/codechameleon_converter.py @@ -132,7 +132,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text if not self.input_supported(input_type): raise ValueError("Input type not supported") - encoded_prompt = str(self.encrypt_function(prompt)) if self.encrypt_function else prompt + encoded_prompt = str(self.encrypt_function(prompt)) if self.encrypt_function is not None else prompt seed_prompt = SeedPrompt.from_yaml_file( pathlib.Path(CONVERTER_SEED_PROMPT_PATH) / "codechameleon_converter.yaml" diff --git a/pyrit/prompt_converter/denylist_converter.py b/pyrit/prompt_converter/denylist_converter.py index a9672e3718..916a961952 100644 --- a/pyrit/prompt_converter/denylist_converter.py +++ b/pyrit/prompt_converter/denylist_converter.py @@ -28,7 +28,7 @@ def __init__( *, converter_target: PromptChatTarget = REQUIRED_VALUE, # type: ignore[assignment] system_prompt_template: Optional[SeedPrompt] = None, - denylist: list[str] = None, + denylist: Optional[list[str]] = None, ): """ Initialize the converter with a target, an optional system prompt template, and a denylist. diff --git a/pyrit/prompt_converter/template_segment_converter.py b/pyrit/prompt_converter/template_segment_converter.py index 8520436471..07ab83e164 100644 --- a/pyrit/prompt_converter/template_segment_converter.py +++ b/pyrit/prompt_converter/template_segment_converter.py @@ -51,18 +51,18 @@ def __init__( ) ) - self._number_parameters = len(self.prompt_template.parameters) + self._number_parameters = len(self.prompt_template.parameters or []) if self._number_parameters < 2: raise ValueError( - f"Template must have at least two parameters, but found {len(self.prompt_template.parameters)}. " + f"Template must have at least two parameters, but found {len(self.prompt_template.parameters or [])}. " f"Template parameters: {self.prompt_template.parameters}" ) # Validate all parameters exist in the template value by attempting to render with empty values try: # Create a dict with empty values for all parameters - empty_values = dict.fromkeys(self.prompt_template.parameters, "") + empty_values = dict.fromkeys(self.prompt_template.parameters or [], "") # This will raise ValueError if any parameter is missing self.prompt_template.render_template_value(**empty_values) except ValueError as e: @@ -107,7 +107,7 @@ async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text segments = self._split_prompt_into_segments(prompt) filled_template = self.prompt_template.render_template_value( - **dict(zip(self.prompt_template.parameters, segments, strict=False)) + **dict(zip(self.prompt_template.parameters or [], segments, strict=False)) ) return ConverterResult(output_text=filled_template, output_type="text") diff --git a/pyrit/prompt_normalizer/normalizer_request.py b/pyrit/prompt_normalizer/normalizer_request.py index 30869a09b2..020d55429c 100644 --- a/pyrit/prompt_normalizer/normalizer_request.py +++ b/pyrit/prompt_normalizer/normalizer_request.py @@ -25,8 +25,8 @@ def __init__( self, *, message: Message, - request_converter_configurations: list[PromptConverterConfiguration] = None, - response_converter_configurations: list[PromptConverterConfiguration] = None, + request_converter_configurations: Optional[list[PromptConverterConfiguration]] = None, + response_converter_configurations: Optional[list[PromptConverterConfiguration]] = None, conversation_id: Optional[str] = None, ): """ diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index b730a58669..5ed281cfe9 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -32,7 +32,13 @@ class PromptNormalizer: Handles normalization and processing of prompts before they are sent to targets. """ - _memory: MemoryInterface = None + _memory: Optional[MemoryInterface] = None + + @property + def memory(self) -> MemoryInterface: + if self._memory is None: + raise ValueError("Memory is not initialized") + return self._memory def __init__(self, start_token: str = "⟪", end_token: str = "⟫") -> None: """ @@ -55,7 +61,7 @@ async def send_prompt_async( response_converter_configurations: list[PromptConverterConfiguration] | None = None, labels: Optional[dict[str, str]] = None, attack_identifier: Optional[ComponentIdentifier] = None, - ) -> Message: + ) -> Optional[Message]: """ Send a single request to a target. @@ -105,10 +111,10 @@ async def send_prompt_async( try: responses = await target.send_prompt_async(message=request) - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) except EmptyResponseException: # Empty responses are retried, but we don't want them to stop execution - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) responses = [ construct_response_from_request( @@ -121,7 +127,7 @@ async def send_prompt_async( except Exception as ex: # Ensure request to memory before processing exception - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) error_response = construct_response_from_request( request=request.message_pieces[0], @@ -131,7 +137,7 @@ async def send_prompt_async( ) await self._calc_hash(request=error_response) - self._memory.add_message_to_memory(request=error_response) + self.memory.add_message_to_memory(request=error_response) cid = request.message_pieces[0].conversation_id if request and request.message_pieces else None raise Exception(f"Error sending prompt with conversation ID: {cid}") from ex @@ -147,7 +153,7 @@ async def send_prompt_async( if is_last: await self.convert_values(converter_configurations=response_converter_configurations, message=resp) await self._calc_hash(request=resp) - self._memory.add_message_to_memory(request=resp) + self.memory.add_message_to_memory(request=resp) # Return the last response for backward compatibility return responses[-1] @@ -312,6 +318,6 @@ async def add_prepended_conversation_to_memory( # and if not, this won't hurt anything piece.id = uuid4() - self._memory.add_message_to_memory(request=request) + self.memory.add_message_to_memory(request=request) return prepended_conversation diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 824c104f47..e624e53628 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -134,6 +134,8 @@ async def _upload_blob_async(self, file_name: str, data: bytes, content_type: st # If not, the file will be put in the root of the container. blob_path = f"{blob_prefix}/{file_name}" if blob_prefix else file_name try: + if self._client_async is None: + raise ValueError("Blob storage client not initialized") blob_client = self._client_async.get_blob_client(blob=blob_path) if await blob_client.exists(): logger.info(msg=f"Blob {blob_path} already exists. Deleting it before uploading a new version.") diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 85da9e084c..eb72dbf579 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -168,7 +168,7 @@ def _load_from_path(self, path: str, **kwargs: Any) -> None: **kwargs: Additional keyword arguments to pass to the model loader. """ logger.info(f"Loading model and tokenizer from path: {path}...") - self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call, unused-ignore] + self.tokenizer = AutoTokenizer.from_pretrained( path, trust_remote_code=self.trust_remote_code ) self.model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=self.trust_remote_code, **kwargs) @@ -230,23 +230,23 @@ async def load_model_and_tokenizer(self) -> None: ".cache", "huggingface", "hub", - f"models--{self.model_id.replace('/', '--')}", + f"models--{(self.model_id or '').replace('/', '--')}", ) if self.necessary_files is None: # Download all files if no specific files are provided logger.info(f"Downloading all files for {self.model_id}...") - await download_specific_files(self.model_id, None, self.huggingface_token, Path(cache_dir)) + await download_specific_files(self.model_id or "", None, self.huggingface_token, Path(cache_dir)) else: # Download only the necessary files logger.info(f"Downloading specific files for {self.model_id}...") await download_specific_files( - self.model_id, self.necessary_files, self.huggingface_token, Path(cache_dir) + self.model_id or "", self.necessary_files, self.huggingface_token, Path(cache_dir) ) # Load the tokenizer and model from the specified directory logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...") - self.tokenizer = AutoTokenizer.from_pretrained( # type: ignore[no-untyped-call, unused-ignore] + self.tokenizer = AutoTokenizer.from_pretrained( self.model_id, cache_dir=cache_dir, trust_remote_code=self.trust_remote_code ) self.model = AutoModelForCausalLM.from_pretrained( @@ -257,7 +257,7 @@ async def load_model_and_tokenizer(self) -> None: ) # Move the model to the correct device - self.model = self.model.to(self.device) # type: ignore[arg-type] + self.model = self.model.to(self.device) # Debug prints to check types logger.info(f"Model loaded: {type(self.model)}") @@ -309,7 +309,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: try: # Ensure model is on the correct device (should already be the case from `load_model_and_tokenizer`) - self.model.to(self.device) # type: ignore[arg-type] + self.model.to(self.device) # Record the length of the input tokens to later extract only the generated tokens input_length = input_ids.shape[-1] @@ -345,7 +345,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: response = construct_response_from_request( request=request, response_text_pieces=[assistant_response], - prompt_metadata={"model_id": model_identifier}, + prompt_metadata={"model_id": model_identifier or ""}, ) return [response] diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index ef0950c0f3..e79ebd22e9 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -232,7 +232,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handling - automatically detects ChatCompletion and validates response = await self._handle_openai_request( - api_call=lambda: self._async_client.chat.completions.create(**body), + api_call=lambda: self._client.chat.completions.create(**body), request=message, ) return [response] diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index e0000c148a..c033037c54 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -145,7 +145,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handler - automatically detects Completion and validates response = await self._handle_openai_request( - api_call=lambda: self._async_client.completions.create(**request_params), # type: ignore[call-overload] + api_call=lambda: self._client.completions.create(**request_params), request=message, ) return [response] diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 8734adb776..2ec0cdd958 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -181,7 +181,7 @@ async def _send_generate_request_async(self, message: Message) -> Message: # Use unified error handler for consistent error handling return await self._handle_openai_request( - api_call=lambda: self._async_client.images.generate(**image_generation_args), + api_call=lambda: self._client.images.generate(**image_generation_args), request=message, ) @@ -231,7 +231,7 @@ async def _send_edit_request_async(self, message: Message) -> Message: image_edit_args["style"] = self.style return await self._handle_openai_request( - api_call=lambda: self._async_client.images.edit(**image_edit_args), + api_call=lambda: self._client.images.edit(**image_edit_args), request=message, ) diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 9951b6db92..0b7a7332c8 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -532,7 +532,7 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handling - automatically detects Response and validates result = await self._handle_openai_request( - api_call=lambda body=body: self._async_client.responses.create(**body), + api_call=lambda body=body: self._client.responses.create(**body), request=message, ) diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 0549cd8f62..fce2580161 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -60,6 +60,12 @@ class OpenAITarget(PromptTarget): _async_client: Optional[AsyncOpenAI] = None + @property + def _client(self) -> AsyncOpenAI: + if self._async_client is None: + raise ValueError("AsyncOpenAI client is not initialized") + return self._async_client + def __init__( self, *, @@ -425,6 +431,8 @@ async def _handle_openai_request( # Extract MessagePiece for validation and construction (most targets use single piece) request_piece = request.message_pieces[0] if request.message_pieces else None + if request_piece is None: + raise ValueError("No message pieces in request") # Check for content filter via subclass implementation if self._check_content_filter(response): @@ -451,6 +459,10 @@ def model_dump_json(self) -> str: return error_str request_piece = request.message_pieces[0] if request.message_pieces else None + if request_piece is None: + raise ValueError("No message pieces in request") + if request_piece is None: + raise ValueError("No message pieces in request") return self._handle_content_filter_response(_ErrorResponse(), request_piece) except BadRequestError as e: # Handle 400 errors - includes input policy filters and some Azure output-filter 400s @@ -469,6 +481,8 @@ def model_dump_json(self) -> str: ) request_piece = request.message_pieces[0] if request.message_pieces else None + if request_piece is None: + raise ValueError("No message pieces in request") return handle_bad_request_exception( response_text=str(payload), request=request_piece, @@ -582,7 +596,7 @@ def _set_openai_env_configuration_vars(self) -> None: raise NotImplementedError def _warn_url_with_api_path( - self, endpoint_url: str, api_path: str, provider_examples: dict[str, str] = None + self, endpoint_url: str, api_path: str, provider_examples: Optional[dict[str, str]] = None ) -> None: """ Warn if URL includes API-specific path that should be handled by the SDK. diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 130bf7274a..ece07de5b5 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -132,12 +132,12 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Use unified error handler for consistent error handling response = await self._handle_openai_request( - api_call=lambda: self._async_client.audio.speech.create( - model=body_parameters["model"], # type: ignore[arg-type] - voice=body_parameters["voice"], # type: ignore[arg-type] - input=body_parameters["input"], # type: ignore[arg-type] - response_format=body_parameters.get("response_format"), # type: ignore[arg-type] - speed=body_parameters.get("speed"), # type: ignore[arg-type] + api_call=lambda: self._client.audio.speech.create( + model=body_parameters["model"], + voice=body_parameters["voice"], + input=body_parameters["input"], + response_format=body_parameters.get("response_format"), + speed=body_parameters.get("speed"), ), request=message, ) diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index f09f5bd679..204a306247 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -194,6 +194,8 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: self._validate_request(message=message) text_piece = message.get_piece_by_type(data_type="text") + if text_piece is None: + raise ValueError("No text piece found in message") # Validate video_path pieces for remix mode (does not strip them) self._validate_video_remix_pieces(message=message) @@ -252,7 +254,7 @@ async def _send_text_plus_image_to_video_async( logger.info("Text+Image-to-video mode: Using image as first frame") input_file = await self._prepare_image_input_async(image_piece=image_piece) return await self._handle_openai_request( - api_call=lambda: self._async_client.videos.create_and_poll( + api_call=lambda: self._client.videos.create_and_poll( model=self._model_name, prompt=prompt, size=self._size, @@ -274,7 +276,7 @@ async def _send_text_to_video_async(self, *, prompt: str, request: Message) -> M The response Message with the generated video path. """ return await self._handle_openai_request( - api_call=lambda: self._async_client.videos.create_and_poll( + api_call=lambda: self._client.videos.create_and_poll( model=self._model_name, prompt=prompt, size=self._size, @@ -330,11 +332,11 @@ async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any: Returns: The completed Video object from the OpenAI SDK. """ - video = await self._async_client.videos.remix(video_id, prompt=prompt) + video = await self._client.videos.remix(video_id, prompt=prompt) # Poll until completion if not already done if video.status not in ["completed", "failed"]: - video = await self._async_client.videos.poll(video.id) + video = await self._client.videos.poll(video.id) return video @@ -384,7 +386,7 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> logger.info(f"Video was remixed from: {video.remixed_from_video_id}") # Download video content using SDK - video_response = await self._async_client.videos.download_content(video.id) + video_response = await self._client.videos.download_content(video.id) # Extract bytes from HttpxBinaryResponseContent video_content = video_response.content diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index fe1d3e760f..2b2eec4b78 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -85,14 +85,19 @@ def __init__( endpoint_value = default_values.get_required_value( env_var_name=self.ENDPOINT_URI_ENVIRONMENT_VARIABLE, passed_value=endpoint ) + if endpoint_value is None: + raise ValueError("Endpoint value is required") super().__init__(max_requests_per_minute=max_requests_per_minute, endpoint=endpoint_value) - self._api_version = api_version + self._api_version = api_version or "2024-09-01" # API key is required - either from parameter or environment variable - self._api_key = default_values.get_required_value( + _api_key_value = default_values.get_required_value( env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key ) + if _api_key_value is None: + raise ValueError("API key is required") + self._api_key = _api_key_value self._force_entry_field: PromptShieldEntryField = field diff --git a/pyrit/prompt_target/rpc_client.py b/pyrit/prompt_target/rpc_client.py index f3012a39fb..ccaae11666 100644 --- a/pyrit/prompt_target/rpc_client.py +++ b/pyrit/prompt_target/rpc_client.py @@ -76,8 +76,12 @@ def wait_for_prompt(self) -> MessagePiece: Raises: RPCClientStoppedException: If the client has been stopped. """ + if self._prompt_received_sem is None: + raise ValueError("Semaphore not initialized") self._prompt_received_sem.acquire() if self._is_running: + if self._prompt_received is None: + raise ValueError("No prompt received") return self._prompt_received raise RPCClientStoppedException @@ -88,6 +92,8 @@ def send_message(self, response: bool) -> None: Args: response (bool): True if the prompt is safe, False if unsafe. """ + if self._prompt_received is None: + raise ValueError("No prompt received") score = Score( score_value=str(response), score_type="true_false", @@ -101,6 +107,8 @@ def send_message(self, response: bool) -> None: class_module="pyrit.prompt_target.rpc_client", ), ) + if self._c is None: + raise ValueError("RPC connection not initialized") self._c.root.receive_score(score) def _wait_for_server_avaible(self) -> None: @@ -114,6 +122,8 @@ def stop(self) -> None: Stop the client. """ # Send a signal to the thread to stop + if self._shutdown_event is None: + raise ValueError("Shutdown event not initialized") self._shutdown_event.set() if self._bgsrv_thread is not None: @@ -130,11 +140,15 @@ def reconnect(self) -> None: def _receive_prompt(self, message_piece: MessagePiece, task: Optional[str] = None) -> None: print(f"Received prompt: {message_piece}") self._prompt_received = message_piece + if self._prompt_received_sem is None: + raise ValueError("Semaphore not initialized") self._prompt_received_sem.release() def _ping(self) -> None: try: while self._is_running: + if self._c is None: + raise ValueError("RPC connection not initialized") self._c.root.receive_ping() time.sleep(1.5) if not self._is_running: @@ -152,15 +166,23 @@ def _bgsrv_lifecycle(self) -> None: self._ping_thread.start() # Register callback + if self._c is None: + raise ValueError("RPC connection not initialized") self._c.root.callback_score_prompt(self._receive_prompt) # Wait for the server to be disconnected + if self._shutdown_event is None: + raise ValueError("Shutdown event not initialized") self._shutdown_event.wait() self._is_running = False # Release the semaphore in case it was waiting + if self._prompt_received_sem is None: + raise ValueError("Semaphore not initialized") self._prompt_received_sem.release() + if self._ping_thread is None: + raise ValueError("Ping thread not initialized") self._ping_thread.join() # Avoid calling stop() twice if the server is already stopped. This can happen if the server is stopped diff --git a/pyrit/prompt_target/text_target.py b/pyrit/prompt_target/text_target.py index d47e5d6656..a00d617c72 100644 --- a/pyrit/prompt_target/text_target.py +++ b/pyrit/prompt_target/text_target.py @@ -76,7 +76,7 @@ def import_scores_from_csv(self, csv_file_path: Path) -> list[MessagePiece]: original_value=row["value"], original_value_data_type=row.get("data_type", None), # type: ignore[arg-type] conversation_id=row.get("conversation_id", None), - sequence=int(sequence_str) if sequence_str else None, + sequence=int(sequence_str) if sequence_str else 0, labels=labels, response_error=row.get("response_error", None), # type: ignore[arg-type] prompt_target_identifier=self.get_identifier(), diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index ad9ed2c641..90494b3dc9 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -589,7 +589,7 @@ def _validate_request(self, *, message: Message) -> None: if piece_type == "image_path": mime_type = DataTypeSerializer.get_mime_type(piece.converted_value) - if not mime_type.startswith("image/"): + if not mime_type or not mime_type.startswith("image/"): raise ValueError( f"Invalid image format for image_path: {piece.converted_value}. " f"Detected MIME type: {mime_type}." diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index cea7e16203..8accc2ab03 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -91,7 +91,8 @@ def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: boo self._discovery_path = Path(PYRIT_PATH) / "setup" / "initializers" # At this point _discovery_path is guaranteed to be a Path - assert self._discovery_path is not None + if self._discovery_path is None: + raise ValueError("self._discovery_path is not initialized") # Track file paths for collision detection and resolution self._initializer_paths: dict[str, Path] = {} diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index 443dd6c43f..e670c21922 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -612,7 +612,8 @@ async def _execute_scenario_async(self) -> ScenarioResult: # Type narrowing: _scenario_result_id is guaranteed to be non-None at this point # (verified in run_async before calling this method) - assert self._scenario_result_id is not None + if self._scenario_result_id is None: + raise ValueError("self._scenario_result_id is not initialized") scenario_result_id: str = self._scenario_result_id # Increment number_tries at the start of each run diff --git a/pyrit/scenario/scenarios/airt/content_harms.py b/pyrit/scenario/scenarios/airt/content_harms.py index 0fcc816ad4..9eea1bed4c 100644 --- a/pyrit/scenario/scenarios/airt/content_harms.py +++ b/pyrit/scenario/scenarios/airt/content_harms.py @@ -201,7 +201,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) @@ -212,7 +212,7 @@ def _get_default_scorer(self) -> TrueFalseInverterScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=0.9, ) diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index be084e6e90..6d126bd336 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -170,7 +170,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: presence_of_malware = SelfAskTrueFalseScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ), true_false_question_path=SCORER_SEED_PROMPT_PATH / "true_false_question" / "malware.yaml", @@ -180,7 +180,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) ) @@ -200,7 +200,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index c3e1e72dbe..094a24f8d7 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -125,7 +125,7 @@ def __init__( scenario_result_id: Optional[str] = None, num_templates: Optional[int] = None, num_attempts: int = 1, - jailbreak_names: list[str] = None, + jailbreak_names: Optional[list[str]] = None, ) -> None: """ Initialize the jailbreak scenario. @@ -207,7 +207,7 @@ def _get_default_objective_scorer(self) -> TrueFalseScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) ) @@ -223,7 +223,7 @@ def _create_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) @@ -316,7 +316,7 @@ async def _get_atomic_attack_from_strategy_async( template_name = Path(jailbreak_template_name).stem return AtomicAttack( - atomic_attack_name=f"jailbreak_{template_name}", attack=attack, seed_groups=self._seed_groups + atomic_attack_name=f"jailbreak_{template_name}", attack=attack, seed_groups=self._seed_groups or [] ) async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: diff --git a/pyrit/scenario/scenarios/airt/leakage.py b/pyrit/scenario/scenarios/airt/leakage.py index 61c1f13e13..f8f23e57ec 100644 --- a/pyrit/scenario/scenarios/airt/leakage.py +++ b/pyrit/scenario/scenarios/airt/leakage.py @@ -196,7 +196,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: presence_of_leakage = SelfAskTrueFalseScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ), true_false_question_path=SCORER_SEED_PROMPT_PATH / "true_false_question" / "leakage.yaml", @@ -209,7 +209,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) ) @@ -229,7 +229,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 16320231c3..4f8627ecec 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -301,7 +301,7 @@ def _resolve_seed_groups(self) -> ResolvedSeedData: if harm_category_filter: seed_groups = self._filter_by_harm_category( - seed_groups=seed_groups, + seed_groups=seed_groups or [], harm_category=harm_category_filter, ) logger.info( @@ -367,7 +367,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=0.7, ) @@ -407,7 +407,7 @@ def _get_scorer(self, subharm: Optional[str] = None) -> FloatScaleThresholdScore endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") azure_openai_chat_target = OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) @@ -474,7 +474,7 @@ def _create_single_turn_attacks( AtomicAttack( atomic_attack_name="psychosocial_single_turn", attack=prompt_sending, - seed_groups=seed_groups, + seed_groups=seed_groups or [], memory_labels=self._memory_labels, ) ) @@ -488,7 +488,7 @@ def _create_single_turn_attacks( AtomicAttack( atomic_attack_name="psychosocial_role_play", attack=role_play, - seed_groups=seed_groups, + seed_groups=seed_groups or [], memory_labels=self._memory_labels, ) ) @@ -525,7 +525,7 @@ def _create_multi_turn_attack( return AtomicAttack( atomic_attack_name="psychosocial_crescendo_turn", attack=crescendo, - seed_groups=seed_groups, + seed_groups=seed_groups or [], memory_labels=self._memory_labels, ) diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index 98ae7b338d..113cdf0045 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -196,7 +196,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scam_materials = SelfAskTrueFalseScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=0.9, ), @@ -207,7 +207,7 @@ def _get_default_objective_scorer(self) -> TrueFalseCompositeScorer: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), ) ) @@ -225,7 +225,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) @@ -313,7 +313,7 @@ def _get_atomic_attack_from_strategy(self, strategy: str) -> AtomicAttack: return AtomicAttack( atomic_attack_name=f"scam_{strategy}", attack=attack_strategy, - seed_groups=self._seed_groups, + seed_groups=self._seed_groups or [], memory_labels=self._memory_labels, ) diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index afbbfabd21..173fb2cadc 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -351,7 +351,7 @@ def _get_default_adversarial_target(self) -> OpenAIChatTarget: endpoint = os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") return OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=1.2, ) @@ -366,7 +366,7 @@ def _get_default_scoring_config(self) -> AttackScoringConfig: scorer=SelfAskRefusalScorer( chat_target=OpenAIChatTarget( endpoint=endpoint, - api_key=get_azure_openai_auth(endpoint), + api_key=get_azure_openai_auth(endpoint or ""), model_name=os.environ.get("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"), temperature=0.9, ) @@ -534,7 +534,7 @@ def _get_attack( # Create the adversarial config from self._adversarial_target attack_adversarial_config = AttackAdversarialConfig(target=self._adversarial_chat) - kwargs["attack_adversarial_config"] = attack_adversarial_config + kwargs["attack_adversarial_config"] = attack_adversarial_config # type: ignore[assignment] # Add attack-specific kwargs if provided if attack_kwargs: diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 1395e3b968..9c7e7e3f46 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -39,7 +39,7 @@ def _is_compliant_wav(input_path: str, *, sample_rate: int, channels: int) -> bo is_pcm_s16 = codec_name == "pcm_s16le" is_correct_rate = stream.rate == sample_rate is_correct_channels = stream.channels == channels - return is_pcm_s16 and is_correct_rate and is_correct_channels + return bool(is_pcm_s16 and is_correct_rate and is_correct_channels) except Exception: return False diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 16aa3d75ab..1ca064d5d4 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -180,7 +180,7 @@ async def evaluate_async( file_mapping: Optional["ScorerEvalDatasetFiles"] = None, *, num_scorer_trials: int = 3, - update_registry_behavior: "RegistryUpdateBehavior" = None, + update_registry_behavior: "Optional[RegistryUpdateBehavior]" = None, max_concurrency: int = 10, ) -> Optional["ScorerMetrics"]: """ diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index ab72c7ba16..18388d943e 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -140,6 +140,18 @@ class LikertScalePaths(enum.Enum): Path(HARM_DEFINITION_PATH, "phishing.yaml").resolve(), None, ) + AI_SUPPLY_CHAIN_SCALE = ( + Path(HARM_DEFINITION_PATH, "ai_supply_chain.yaml").resolve(), + None, + ) + AI_SYSTEM_TRANSPARENCY_SCALE = ( + Path(HARM_DEFINITION_PATH, "ai_system_transparency.yaml").resolve(), + None, + ) + AI_GOVERNANCE_FAILURE_SCALE = ( + Path(HARM_DEFINITION_PATH, "ai_governance_failure.yaml").resolve(), + None, + ) @property def path(self) -> Path: diff --git a/pyrit/score/human/human_in_the_loop_gradio.py b/pyrit/score/human/human_in_the_loop_gradio.py index 3237ec7028..a5f802dd14 100644 --- a/pyrit/score/human/human_in_the_loop_gradio.py +++ b/pyrit/score/human/human_in_the_loop_gradio.py @@ -105,6 +105,8 @@ def retrieve_score(self, request_prompt: MessagePiece, *, objective: Optional[st self._rpc_server.wait_for_client() self._rpc_server.send_score_prompt(request_prompt) score = self._rpc_server.wait_for_score() + if score is None: + raise ValueError("No score received from RPC server") score.scorer_class_identifier = self.get_identifier() return [score] diff --git a/pyrit/score/printer/console_scorer_printer.py b/pyrit/score/printer/console_scorer_printer.py index c8270a10a9..a0952fb26b 100644 --- a/pyrit/score/printer/console_scorer_printer.py +++ b/pyrit/score/printer/console_scorer_printer.py @@ -77,16 +77,16 @@ def _get_quality_color( """ if higher_is_better: if value >= good_threshold: - return Fore.GREEN # type: ignore[no-any-return] + return Fore.GREEN if value < bad_threshold: - return Fore.RED # type: ignore[no-any-return] - return Fore.CYAN # type: ignore[no-any-return] + return Fore.RED + return Fore.CYAN # Lower is better (e.g., MAE, score time) if value <= good_threshold: - return Fore.GREEN # type: ignore[no-any-return] + return Fore.GREEN if value > bad_threshold: - return Fore.RED # type: ignore[no-any-return] - return Fore.CYAN # type: ignore[no-any-return] + return Fore.RED + return Fore.CYAN def print_objective_scorer(self, *, scorer_identifier: ComponentIdentifier) -> None: """ diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index c1ad1910a6..ce27cb76a7 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -268,7 +268,7 @@ async def evaluate_async( file_mapping: Optional[ScorerEvalDatasetFiles] = None, *, num_scorer_trials: int = 3, - update_registry_behavior: RegistryUpdateBehavior = None, + update_registry_behavior: Optional[RegistryUpdateBehavior] = None, max_concurrency: int = 10, ) -> Optional[ScorerMetrics]: """ @@ -355,7 +355,7 @@ async def score_text_async(self, text: str, *, objective: Optional[str] = None) ] ) - request.message_pieces[0].id = None + request.message_pieces[0].id = None # type: ignore[assignment] return await self.score_async(request, objective=objective) async def score_image_async(self, image_path: str, *, objective: Optional[str] = None) -> list[Score]: @@ -379,7 +379,7 @@ async def score_image_async(self, image_path: str, *, objective: Optional[str] = ] ) - request.message_pieces[0].id = None + request.message_pieces[0].id = None # type: ignore[assignment] return await self.score_async(request, objective=objective) async def score_prompts_batch_async( diff --git a/pyrit/score/true_false/prompt_shield_scorer.py b/pyrit/score/true_false/prompt_shield_scorer.py index 652623d8cd..4db52eb17c 100644 --- a/pyrit/score/true_false/prompt_shield_scorer.py +++ b/pyrit/score/true_false/prompt_shield_scorer.py @@ -119,17 +119,14 @@ def _parse_response_to_boolean_list(self, response: str) -> list[bool]: """ response_json: dict[str, Any] = json.loads(response) - user_detections = [] - document_detections = [] - user_prompt_attack: dict[str, bool] = response_json.get("userPromptAnalysis", False) documents_attack: list[dict[str, Any]] = response_json.get("documentsAnalysis", False) - user_detections = [False] if not user_prompt_attack else [user_prompt_attack.get("attackDetected")] + user_detections: list[bool] = [False] if not user_prompt_attack else [bool(user_prompt_attack.get("attackDetected"))] if not documents_attack: - document_detections = [False] + document_detections: list[bool] = [False] else: - document_detections = [document.get("attackDetected") for document in documents_attack] + document_detections = [bool(document.get("attackDetected")) for document in documents_attack] return user_detections + document_detections diff --git a/pyrit/score/true_false/self_ask_true_false_scorer.py b/pyrit/score/true_false/self_ask_true_false_scorer.py index da1054274d..d79060fcb4 100644 --- a/pyrit/score/true_false/self_ask_true_false_scorer.py +++ b/pyrit/score/true_false/self_ask_true_false_scorer.py @@ -140,6 +140,8 @@ def __init__( if true_false_question_path: true_false_question_path = verify_and_resolve_path(true_false_question_path) true_false_question = yaml.safe_load(true_false_question_path.read_text(encoding="utf-8")) + if true_false_question is None: + raise ValueError("Failed to load true_false_question YAML") for key in ["category", "true_description", "false_description"]: if key not in true_false_question: diff --git a/pyrit/score/true_false/true_false_composite_scorer.py b/pyrit/score/true_false/true_false_composite_scorer.py index c66c24d437..45d0dc4cdb 100644 --- a/pyrit/score/true_false/true_false_composite_scorer.py +++ b/pyrit/score/true_false/true_false_composite_scorer.py @@ -113,7 +113,8 @@ async def _score_async( # Ensure the message piece has an ID piece_id = message.message_pieces[0].id - assert piece_id is not None, "Message piece must have an ID" + if piece_id is None: + raise ValueError("Message piece must have an ID") return_score = Score( score_value=str(result.value), diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index 96740565d8..6ac0c50286 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -109,8 +109,10 @@ async def initialize_async(self) -> None: scorer_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2") # Type assertions - safe because validate() already checked these - assert converter_endpoint is not None - assert scorer_endpoint is not None + if converter_endpoint is None: + raise ValueError("converter_endpoint is not initialized") + if scorer_endpoint is None: + raise ValueError("scorer_endpoint is not initialized") # model name can be empty in certain cases (e.g., custom model deployments that don't need model name) # Check for API keys first, fall back to Entra auth if not set @@ -125,7 +127,7 @@ async def initialize_async(self) -> None: # 1. Setup converter target self._setup_converter_target( - endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name + endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name or "" ) # 2. Setup scorers @@ -133,12 +135,12 @@ async def initialize_async(self) -> None: endpoint=scorer_endpoint, api_key=scorer_api_key, content_safety_api_key=content_safety_api_key, - model_name=scorer_model_name, + model_name=scorer_model_name or "", ) # 3. Setup adversarial targets self._setup_adversarial_targets( - endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name + endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name or "" ) def _setup_converter_target(self, *, endpoint: str, api_key: str, model_name: str) -> None: diff --git a/pyrit/setup/initializers/components/scorers.py b/pyrit/setup/initializers/components/scorers.py index 06b304ebc6..c70659d85a 100644 --- a/pyrit/setup/initializers/components/scorers.py +++ b/pyrit/setup/initializers/components/scorers.py @@ -159,23 +159,24 @@ async def initialize_async(self) -> None: unsafe_temp9: Optional[PromptChatTarget] = target_registry.get_instance_by_name(GPT4O_UNSAFE_TEMP9_TARGET) # type: ignore[assignment] # Refusal Scorers - self._try_register(scorer_registry, REFUSAL_GPT4O, lambda: SelfAskRefusalScorer(chat_target=gpt4o), gpt4o) + self._try_register(scorer_registry, REFUSAL_GPT4O, lambda: SelfAskRefusalScorer(chat_target=gpt4o), # type: ignore[arg-type] + gpt4o,) self._try_register( scorer_registry, INVERTED_REFUSAL_GPT4O, - lambda: TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=gpt4o)), + lambda: TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=gpt4o)), # type: ignore[arg-type] gpt4o, ) self._try_register( scorer_registry, INVERTED_REFUSAL_GPT4O_UNSAFE, - lambda: TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=unsafe)), + lambda: TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=unsafe)), # type: ignore[arg-type] unsafe, ) self._try_register( scorer_registry, INVERTED_REFUSAL_GPT4O_UNSAFE_TEMP9, - lambda: TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=unsafe_temp9)), + lambda: TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=unsafe_temp9)), # type: ignore[arg-type] unsafe_temp9, ) @@ -204,7 +205,7 @@ async def initialize_async(self) -> None: aggregator=TrueFalseScoreAggregator.AND, scorers=[ FloatScaleThresholdScorer(scorer=AzureContentFilterScorer(), threshold=0.5), - TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=unsafe_temp9)), + TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=unsafe_temp9)), # type: ignore[arg-type] ], ), unsafe_temp9, @@ -213,7 +214,7 @@ async def initialize_async(self) -> None: scorer_registry, SCALE_GPT4O_TEMP9_THRESHOLD_09, lambda: FloatScaleThresholdScorer( - scorer=SelfAskScaleScorer(chat_target=gpt4o_temp9), + scorer=SelfAskScaleScorer(chat_target=gpt4o_temp9), # type: ignore[arg-type] threshold=0.9, ), gpt4o_temp9, @@ -225,10 +226,10 @@ async def initialize_async(self) -> None: aggregator=TrueFalseScoreAggregator.AND, scorers=[ FloatScaleThresholdScorer( - scorer=SelfAskScaleScorer(chat_target=gpt4o_temp9), + scorer=SelfAskScaleScorer(chat_target=gpt4o_temp9), # type: ignore[arg-type] threshold=0.9, ), - TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=gpt4o)), + TrueFalseInverterScorer(scorer=SelfAskRefusalScorer(chat_target=gpt4o)), # type: ignore[arg-type] ], ), gpt4o_temp9, @@ -258,7 +259,7 @@ async def initialize_async(self) -> None: scorer_registry, TASK_ACHIEVED_GPT4O_TEMP9, lambda: SelfAskTrueFalseScorer( - chat_target=gpt4o_temp9, + chat_target=gpt4o_temp9, # type: ignore[arg-type] true_false_question_path=TrueFalseQuestionPaths.TASK_ACHIEVED.value, ), gpt4o_temp9, @@ -267,7 +268,7 @@ async def initialize_async(self) -> None: scorer_registry, TASK_ACHIEVED_REFINED_GPT4O_TEMP9, lambda: SelfAskTrueFalseScorer( - chat_target=gpt4o_temp9, + chat_target=gpt4o_temp9, # type: ignore[arg-type] true_false_question_path=TrueFalseQuestionPaths.TASK_ACHIEVED_REFINED.value, ), gpt4o_temp9, @@ -280,7 +281,7 @@ async def initialize_async(self) -> None: self._try_register( scorer_registry, scorer_name, - lambda s=scale: SelfAskLikertScorer(chat_target=gpt4o, likert_scale=s), # type: ignore[misc] + lambda s=scale: SelfAskLikertScorer(chat_target=gpt4o, likert_scale=s), # type: ignore[arg-type, misc] gpt4o, ) diff --git a/pyrit/show_versions.py b/pyrit/show_versions.py index e19fde71ff..301faebdd7 100644 --- a/pyrit/show_versions.py +++ b/pyrit/show_versions.py @@ -56,7 +56,7 @@ def _get_deps_info() -> dict[str, str | None]: from pyrit import __version__ - deps_info = {"pyrit": __version__} + deps_info: dict[str, str | None] = {"pyrit": __version__} from importlib.metadata import PackageNotFoundError, version @@ -78,5 +78,5 @@ def show_versions() -> None: print(f"{k:>10}: {stat}") print("\nPython dependencies:") - for k, stat in deps_info.items(): - print(f"{k:>13}: {stat}") + for k, stat_or_none in deps_info.items(): + print(f"{k:>13}: {stat_or_none}") diff --git a/pyrit/ui/rpc.py b/pyrit/ui/rpc.py index bb9828c11a..8c43e4fe77 100644 --- a/pyrit/ui/rpc.py +++ b/pyrit/ui/rpc.py @@ -97,6 +97,8 @@ def is_client_ready(self) -> bool: def send_score_prompt(self, prompt: MessagePiece, task: Optional[str] = None) -> None: if not self.is_client_ready(): raise RPCClientNotReadyException + if self._callback_score_prompt is None: + raise ValueError("self._callback_score_prompt is not initialized") self._callback_score_prompt(prompt, task) def is_ping_missed(self) -> bool: @@ -165,6 +167,8 @@ def stop(self) -> None: """ self.stop_request() if self._server is not None: + if self._server_thread is None: + raise ValueError("self._server_thread is not initialized") self._server_thread.join() if self._is_alive_thread is not None: @@ -201,7 +205,7 @@ def send_score_prompt(self, prompt: MessagePiece, task: Optional[str] = None) -> self._rpc_service.send_score_prompt(prompt, task) - def wait_for_score(self) -> Score: + def wait_for_score(self) -> Optional[Score]: """ Wait for the client to send a score. Should always return a score, but if the synchronisation fails it will return None. @@ -214,6 +218,8 @@ def wait_for_score(self) -> Score: raise RPCServerStoppedException score_ref = self._rpc_service.pop_score_received() + if self._client_ready_semaphore is None: + raise ValueError("self._client_ready_semaphore is not initialized") self._client_ready_semaphore.release() if score_ref is None: return None diff --git a/pyrit/ui/rpc_client.py b/pyrit/ui/rpc_client.py index d6cae64e32..51a1535d1f 100644 --- a/pyrit/ui/rpc_client.py +++ b/pyrit/ui/rpc_client.py @@ -52,12 +52,18 @@ def start(self) -> None: self._bgsrv_thread.start() def wait_for_prompt(self) -> MessagePiece: + if self._prompt_received_sem is None: + raise ValueError("Semaphore not initialized") self._prompt_received_sem.acquire() if self._is_running: + if self._prompt_received is None: + raise ValueError("No prompt received") return self._prompt_received raise RPCClientStoppedException def send_message(self, response: bool) -> None: + if self._prompt_received is None: + raise ValueError("No prompt received") score = Score( score_value=str(response), score_type="true_false", @@ -71,6 +77,8 @@ def send_message(self, response: bool) -> None: class_module="pyrit.ui.rpc_client", ), ) + if self._c is None: + raise ValueError("RPC connection not initialized") self._c.root.receive_score(score) def _wait_for_server_avaible(self) -> None: @@ -84,6 +92,8 @@ def stop(self) -> None: Stop the client. """ # Send a signal to the thread to stop + if self._shutdown_event is None: + raise ValueError("Shutdown event not initialized") self._shutdown_event.set() if self._bgsrv_thread is not None: @@ -100,11 +110,15 @@ def reconnect(self) -> None: def _receive_prompt(self, message_piece: MessagePiece, task: Optional[str] = None) -> None: print(f"Received prompt: {message_piece}") self._prompt_received = message_piece + if self._prompt_received_sem is None: + raise ValueError("Semaphore not initialized") self._prompt_received_sem.release() def _ping(self) -> None: try: while self._is_running: + if self._c is None: + raise ValueError("RPC connection not initialized") self._c.root.receive_ping() time.sleep(1.5) if not self._is_running: @@ -122,15 +136,23 @@ def _bgsrv_lifecycle(self) -> None: self._ping_thread.start() # Register callback + if self._c is None: + raise ValueError("RPC connection not initialized") self._c.root.callback_score_prompt(self._receive_prompt) # Wait for the server to be disconnected + if self._shutdown_event is None: + raise ValueError("Shutdown event not initialized") self._shutdown_event.wait() self._is_running = False # Release the semaphore in case it was waiting + if self._prompt_received_sem is None: + raise ValueError("Semaphore not initialized") self._prompt_received_sem.release() + if self._ping_thread is None: + raise ValueError("Ping thread not initialized") self._ping_thread.join() # Avoid calling stop() twice if the server is already stopped. This can happen if the server is stopped