diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py
index 864c97cc4..47200f75d 100644
--- a/pyrit/models/__init__.py
+++ b/pyrit/models/__init__.py
@@ -25,6 +25,7 @@
 )
 from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation
 from pyrit.models.identifiers import Identifier
+from pyrit.models.json_response_config import JsonResponseConfig
 from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError
 from pyrit.models.message import (
     Message,
@@ -68,6 +69,7 @@
     "group_message_pieces_into_conversations",
     "Identifier",
     "ImagePathDataTypeSerializer",
+    "JsonResponseConfig",
     "Message",
     "MessagePiece",
     "PromptDataType",
diff --git a/pyrit/models/json_response_config.py b/pyrit/models/json_response_config.py
new file mode 100644
index 000000000..ff0ad993c
--- /dev/null
+++ b/pyrit/models/json_response_config.py
@@ -0,0 +1,47 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+
+@dataclass
+class JsonResponseConfig:
+    """Configuration for requesting a JSON response, optionally constrained by a JSON schema."""
+
+    enabled: bool = False
+    schema: Optional[Dict[str, Any]] = None
+    schema_name: str = "CustomSchema"
+    strict: bool = True
+
+    @classmethod
+    def from_metadata(cls, *, metadata: Optional[Dict[str, Any]]) -> "JsonResponseConfig":
+        """Build a config from prompt metadata; disabled unless "response_format" is "json"."""
+        if not metadata:
+            return cls(enabled=False)
+
+        response_format = metadata.get("response_format")
+        if response_format != "json":
+            return cls(enabled=False)
+
+        schema_val = metadata.get("json_schema")
+        if schema_val:
+            if isinstance(schema_val, str):
+                try:
+                    schema = json.loads(schema_val)
+                except json.JSONDecodeError as e:
+                    raise ValueError(f"Invalid JSON schema provided: {schema_val}") from e
+            else:
+                schema = schema_val
+
+            return cls(
+                enabled=True,
+                schema=schema,
+                schema_name=metadata.get("schema_name", "CustomSchema"),
+                strict=metadata.get("strict", True),
+            )
+
+        return cls(enabled=True)
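A quick sketch of how the new config parses prompt metadata (illustrative only, not part of the patch):

# Sketch: exercising JsonResponseConfig.from_metadata as introduced above.
import json

from pyrit.models import JsonResponseConfig

schema = {"type": "object", "properties": {"name": {"type": "string"}}}

# The schema may be supplied as a dict or as a JSON string; both parse to the same config.
config = JsonResponseConfig.from_metadata(
    metadata={"response_format": "json", "json_schema": json.dumps(schema)}
)
assert config.enabled and config.schema == schema and config.strict is True

# Any other response_format (or missing metadata) leaves JSON mode disabled.
assert JsonResponseConfig.from_metadata(metadata={"response_format": "text"}).enabled is False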
""" - if message_piece.prompt_metadata: - response_format = message_piece.prompt_metadata.get("response_format") - if response_format == "json": - if not self.is_json_response_supported(): - target_name = self.get_identifier()["__type__"] - raise ValueError(f"This target {target_name} does not support JSON response format.") - return True - return False + config = self.get_json_response_config(message_piece=message_piece) + return config.enabled + + def get_json_response_config(self, *, message_piece: MessagePiece) -> JsonResponseConfig: + config = JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) + + if config.enabled and not self.is_json_response_supported(): + target_name = self.get_identifier()["__type__"] + raise ValueError(f"This target {target_name} does not support JSON response format.") + + return config diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index b28cf2565..832c1da3b 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -3,7 +3,7 @@ import json import logging -from typing import Any, MutableSequence, Optional +from typing import Any, Dict, MutableSequence, Optional from pyrit.common import convert_local_image_to_data_url from pyrit.exceptions import ( @@ -14,6 +14,7 @@ from pyrit.models import ( ChatMessage, ChatMessageListDictContent, + JsonResponseConfig, Message, MessagePiece, construct_response_from_request, @@ -245,8 +246,11 @@ async def _build_chat_messages_for_multi_modal_async(self, conversation: Mutable chat_messages.append(chat_message.model_dump(exclude_none=True)) return chat_messages - async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict: + async def _construct_request_body( + self, *, conversation: MutableSequence[Message], json_config: JsonResponseConfig + ) -> dict: messages = await self._build_chat_messages_async(conversation) + response_format = self._build_response_format(json_config) body_parameters = { "model": self._model_name, @@ -260,7 +264,7 @@ async def _construct_request_body(self, conversation: MutableSequence[Message], "seed": self._seed, "n": self._n, "messages": messages, - "response_format": {"type": "json_object"} if is_json_response else None, + "response_format": response_format, } if self._extra_body_parameters: @@ -276,7 +280,6 @@ def _construct_message_from_openai_json( open_ai_str_response: str, message_piece: MessagePiece, ) -> Message: - try: response = json.loads(open_ai_str_response) except json.JSONDecodeError as e: @@ -322,3 +325,19 @@ def _validate_request(self, *, message: Message) -> None: for prompt_data_type in converted_prompt_data_types: if prompt_data_type not in ["text", "image_path"]: raise ValueError(f"This target only supports text and image_path. 
Received: {prompt_data_type}.") + + def _build_response_format(self, json_config: JsonResponseConfig) -> Optional[Dict[str, Any]]: + if not json_config.enabled: + return None + + if json_config.schema: + return { + "type": "json_schema", + "json_schema": { + "name": json_config.schema_name, + "schema": json_config.schema, + "strict": json_config.strict, + }, + } + + return {"type": "json_object"} diff --git a/pyrit/prompt_target/openai/openai_chat_target_base.py b/pyrit/prompt_target/openai/openai_chat_target_base.py index 754e231be..96c23e3ed 100644 --- a/pyrit/prompt_target/openai/openai_chat_target_base.py +++ b/pyrit/prompt_target/openai/openai_chat_target_base.py @@ -15,6 +15,7 @@ ) from pyrit.exceptions.exception_classes import RateLimitException from pyrit.models import ( + JsonResponseConfig, Message, MessagePiece, ) @@ -84,9 +85,9 @@ def __init__( super().__init__(**kwargs) if temperature is not None and (temperature < 0 or temperature > 2): - raise PyritException("temperature must be between 0 and 2 (inclusive).") + raise PyritException(message="temperature must be between 0 and 2 (inclusive).") if top_p is not None and (top_p < 0 or top_p > 1): - raise PyritException("top_p must be between 0 and 1 (inclusive).") + raise PyritException(message="top_p must be between 0 and 1 (inclusive).") self._temperature = temperature self._top_p = top_p @@ -110,14 +111,14 @@ async def send_prompt_async(self, *, message: Message) -> Message: message_piece: MessagePiece = message.message_pieces[0] - is_json_response = self.is_response_format_json(message_piece) + json_response_config = self.get_json_response_config(message_piece=message_piece) conversation = self._memory.get_conversation(conversation_id=message_piece.conversation_id) conversation.append(message) logger.info(f"Sending the following prompt to the prompt target: {message}") - body = await self._construct_request_body(conversation=conversation, is_json_response=is_json_response) + body = await self._construct_request_body(conversation=conversation, json_config=json_response_config) try: str_response: httpx.Response = await net_utility.make_request_and_raise_if_error_async( @@ -159,7 +160,9 @@ async def send_prompt_async(self, *, message: Message) -> Message: return response - async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict: + async def _construct_request_body( + self, *, conversation: MutableSequence[Message], json_config: JsonResponseConfig + ) -> dict: raise NotImplementedError def _construct_message_from_openai_json( diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index b0405e41d..edede5553 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -21,6 +21,7 @@ handle_bad_request_exception, ) from pyrit.models import ( + JsonResponseConfig, Message, MessagePiece, PromptDataType, @@ -291,7 +292,9 @@ def _translate_roles(self, conversation: List[Dict[str, Any]]) -> None: request["role"] = "developer" return - async def _construct_request_body(self, conversation: MutableSequence[Message], is_json_response: bool) -> dict: + async def _construct_request_body( + self, *, conversation: MutableSequence[Message], json_config: JsonResponseConfig + ) -> dict: """ Construct the request body to send to the Responses API. 
@@ -300,6 +303,8 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
         """
         input_items = await self._build_input_for_multi_modal_async(conversation)
 
+        text_format = self._build_text_format(json_config=json_config)
+
         body_parameters = {
             "model": self._model_name,
             "max_output_tokens": self._max_output_tokens,
@@ -308,7 +313,7 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
             "stream": False,
             "input": input_items,
-            # Correct JSON response format per Responses API
-            "response_format": {"type": "json_object"} if is_json_response else None,
+            # The Responses API requests JSON output via the "text" parameter
+            "text": text_format,
         }
 
         if self._extra_body_parameters:
@@ -317,6 +322,23 @@ async def _construct_request_body(self, conversation: MutableSequence[Message],
         # Filter out None values
         return {k: v for k, v in body_parameters.items() if v is not None}
+
+    def _build_text_format(self, json_config: JsonResponseConfig) -> Optional[Dict[str, Any]]:
+        if not json_config.enabled:
+            return None
+
+        if json_config.schema:
+            return {
+                "format": {
+                    "type": "json_schema",
+                    "name": json_config.schema_name,
+                    "schema": json_config.schema,
+                    "strict": json_config.strict,
+                }
+            }
+
+        logger.info("Using json_object format without schema - consider providing a schema for better results")
+        return {"format": {"type": "json_object"}}
 
     def _construct_message_from_openai_json(
         self,
         *,
diff --git a/tests/integration/targets/test_openai_responses_gpt5.py b/tests/integration/targets/test_openai_responses_gpt5.py
index 17c1cca55..1301815f1 100644
--- a/tests/integration/targets/test_openai_responses_gpt5.py
+++ b/tests/integration/targets/test_openai_responses_gpt5.py
@@ -2,24 +2,29 @@
 # Licensed under the MIT license.
 
 
+import json
 import os
 import uuid
 
+import jsonschema
 import pytest
 
 from pyrit.models import MessagePiece
 from pyrit.prompt_target import OpenAIResponseTarget
 
 
-@pytest.mark.asyncio
-async def test_openai_responses_gpt5(sqlite_instance):
-    args = {
+@pytest.fixture()
+def gpt5_args():
+    return {
         "endpoint": os.getenv("AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT"),
         "model_name": os.getenv("AZURE_OPENAI_GPT5_MODEL"),
-        "api_key": os.getenv("AZURE_OPENAI_GPT5_KEY"),
+        "use_entra_auth": True,
     }
 
-    target = OpenAIResponseTarget(**args)
+
+@pytest.mark.asyncio
+async def test_openai_responses_gpt5(sqlite_instance, gpt5_args):
+    target = OpenAIResponseTarget(**gpt5_args)
 
     conv_id = str(uuid.uuid4())
 
@@ -46,3 +51,86 @@ async def test_openai_responses_gpt5(sqlite_instance):
     assert result.message_pieces[1].role == "assistant"
     # Hope that the model manages to give the correct answer somewhere (GPT-5 really should)
     assert "Paris" in result.message_pieces[1].converted_value
+
+
+@pytest.mark.asyncio
+async def test_openai_responses_gpt5_json_schema(sqlite_instance, gpt5_args):
+    target = OpenAIResponseTarget(**gpt5_args)
+
+    conv_id = str(uuid.uuid4())
+
+    developer_piece = MessagePiece(
+        role="developer",
+        original_value="You are an expert in the lore of cats.",
+        original_value_data_type="text",
+        conversation_id=conv_id,
+        attack_identifier={"id": str(uuid.uuid4())},
+    )
+    sqlite_instance.add_message_to_memory(request=developer_piece.to_message())
+
+    cat_schema = {
+        "type": "object",
+        "properties": {
+            "name": {"type": "string", "minLength": 12},
+            "age": {"type": "integer", "minimum": 0, "maximum": 20},
+            "colour": {
+                "type": "array",
+                "items": {"type": "integer", "minimum": 0, "maximum": 255},
+                "minItems": 3,
+                "maxItems": 3,
+            },
+        },
+        "required": ["name", "age", 
"colour"], + "additionalProperties": False, + } + + prompt = "Create a JSON object that describes a mystical cat " + prompt += "with the following properties: name, age, colour." + + user_piece = MessagePiece( + role="user", + original_value=prompt, + original_value_data_type="text", + conversation_id=conv_id, + prompt_metadata={"response_format": "json", "json_schema": json.dumps(cat_schema)}, + ) + + response = await target.send_prompt_async(message=user_piece.to_message()) + + response_content = response.get_value(1) + response_json = json.loads(response_content) + jsonschema.validate(instance=response_json, schema=cat_schema) + + +@pytest.mark.asyncio +async def test_openai_responses_gpt5_json_object(sqlite_instance, gpt5_args): + target = OpenAIResponseTarget(**gpt5_args) + + conv_id = str(uuid.uuid4()) + + developer_piece = MessagePiece( + role="developer", + original_value="You are an expert in the lore of cats.", + original_value_data_type="text", + conversation_id=conv_id, + attack_identifier={"id": str(uuid.uuid4())}, + ) + + sqlite_instance.add_message_to_memory(request=developer_piece.to_message()) + + prompt = "Create a JSON object that describes a mystical cat " + prompt += "with the following properties: name, age, colour." + + user_piece = MessagePiece( + role="user", + original_value=prompt, + original_value_data_type="text", + conversation_id=conv_id, + prompt_metadata={"response_format": "json"}, + ) + response = await target.send_prompt_async(message=user_piece.to_message()) + + response_content = response.get_value(1) + response_json = json.loads(response_content) + assert response_json is not None + # Can't assert more, since the failure could be due to a bad generation by the model diff --git a/tests/unit/models/test_json_response_config.py b/tests/unit/models/test_json_response_config.py new file mode 100644 index 000000000..f715907ab --- /dev/null +++ b/tests/unit/models/test_json_response_config.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
diff --git a/tests/unit/models/test_json_response_config.py b/tests/unit/models/test_json_response_config.py
new file mode 100644
index 000000000..f715907ab
--- /dev/null
+++ b/tests/unit/models/test_json_response_config.py
@@ -0,0 +1,76 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import json
+
+import pytest
+
+from pyrit.models import JsonResponseConfig
+
+
+def test_with_none():
+    config = JsonResponseConfig.from_metadata(metadata=None)
+    assert config.enabled is False
+    assert config.schema is None
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
+
+
+def test_with_json_object():
+    metadata = {
+        "response_format": "json",
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is True
+    assert config.schema is None
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
+
+
+def test_with_json_string_schema():
+    schema = {"type": "object", "properties": {"name": {"type": "string"}}}
+    metadata = {
+        "response_format": "json",
+        "json_schema": json.dumps(schema),
+        "schema_name": "TestSchema",
+        "strict": False,
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is True
+    assert config.schema == schema
+    assert config.schema_name == "TestSchema"
+    assert config.strict is False
+
+
+def test_with_json_schema_object():
+    schema = {"type": "object", "properties": {"age": {"type": "integer"}}}
+    metadata = {
+        "response_format": "json",
+        "json_schema": schema,
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is True
+    assert config.schema == schema
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
+
+
+def test_with_invalid_json_schema_string():
+    metadata = {
+        "response_format": "json",
+        "json_schema": "{invalid_json: true}",
+    }
+    with pytest.raises(ValueError) as e:
+        JsonResponseConfig.from_metadata(metadata=metadata)
+    assert "Invalid JSON schema provided" in str(e.value)
+
+
+def test_other_response_format():
+    metadata = {
+        "response_format": "something_really_improbable_to_have_here",
+    }
+    config = JsonResponseConfig.from_metadata(metadata=metadata)
+    assert config.enabled is False
+    assert config.schema is None
+    assert config.schema_name == "CustomSchema"
+    assert config.strict is True
diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py
index c01701a19..8e02ac41e 100644
--- a/tests/unit/target/test_openai_chat_target.py
+++ b/tests/unit/target/test_openai_chat_target.py
@@ -23,7 +23,7 @@
     RateLimitException,
 )
 from pyrit.memory.memory_interface import MemoryInterface
-from pyrit.models import Message, MessagePiece
+from pyrit.models import JsonResponseConfig, Message, MessagePiece
 from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
 
 
@@ -126,7 +126,6 @@ def test_init_is_json_supported_can_be_set_to_true(patch_central_database):
 
 @pytest.mark.asyncio()
 async def test_build_chat_messages_for_multi_modal(target: OpenAIChatTarget):
-
     image_request = get_image_message_piece()
     entries = [
         Message(
@@ -178,23 +177,31 @@ async def test_construct_request_body_includes_extra_body_params(
 
     request = Message(message_pieces=[dummy_text_message_piece])
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     assert body["key"] == "value"
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("is_json", [True, False])
-async def test_construct_request_body_includes_json(
-    is_json, target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece
-):
+async def test_construct_request_body_json_object(target: OpenAIChatTarget, 
dummy_text_message_piece: MessagePiece):
+    request = Message(message_pieces=[dummy_text_message_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata={"response_format": "json"})
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
+    assert body["response_format"] == {"type": "json_object"}
+
+
+@pytest.mark.asyncio
+async def test_construct_request_body_json_schema(target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece):
+    schema_obj = {"type": "object", "properties": {"name": {"type": "string"}}}
     request = Message(message_pieces=[dummy_text_message_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata={"response_format": "json", "json_schema": schema_obj})
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=is_json)
-    if is_json:
-        assert body["response_format"] == {"type": "json_object"}
-    else:
-        assert "response_format" not in body
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
+    assert body["response_format"] == {
+        "type": "json_schema",
+        "json_schema": {"name": "CustomSchema", "schema": schema_obj, "strict": True},
+    }
@@ -203,13 +210,15 @@ async def test_construct_request_body_removes_empty_values(
 ):
     request = Message(message_pieces=[dummy_text_message_piece])
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     assert "max_completion_tokens" not in body
     assert "max_tokens" not in body
     assert "temperature" not in body
     assert "top_p" not in body
     assert "frequency_penalty" not in body
     assert "presence_penalty" not in body
+    assert "response_format" not in body
@@ -217,8 +226,9 @@ async def test_construct_request_body_serializes_text_message(
     target: OpenAIChatTarget, dummy_text_message_piece: MessagePiece
 ):
     request = Message(message_pieces=[dummy_text_message_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     assert (
         body["messages"][0]["content"] == "dummy text"
     ), "Text messages are serialized in a simple way that's more broadly supported"
@@ -231,8 +241,9 @@ async def test_construct_request_body_serializes_complex_message(
     image_piece = get_image_message_piece()
     image_piece.conversation_id = dummy_text_message_piece.conversation_id  # Match conversation IDs
     request = Message(message_pieces=[dummy_text_message_piece, image_piece])
+    jrc = JsonResponseConfig.from_metadata(metadata=None)
 
-    body = await target._construct_request_body(conversation=[request], is_json_response=False)
+    body = await target._construct_request_body(conversation=[request], json_config=jrc)
     messages = body["messages"][0]["content"]
     assert len(messages) == 2, "Complex messages are serialized as a list"
     assert messages[0]["type"] == "text", "Text messages are serialized properly when multi-modal"
@@ -314,7 +325,6 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory(
     side_effect = httpx.HTTPStatusError("Rate Limit Reached", response=response, request=MagicMock())
 
     with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect):
-
         message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", 
original_value="Hello")]) with pytest.raises(RateLimitException) as rle: @@ -437,7 +447,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di with patch( "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock ) as mock_create: - openai_mock_return = MagicMock() openai_mock_return.text = json.dumps(openai_response_json) mock_create.return_value = openai_mock_return @@ -451,7 +460,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di @pytest.mark.asyncio async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIChatTarget): - message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")]) response = MagicMock() @@ -462,7 +470,6 @@ async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIChat with patch( "pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect ) as mock_request: - with pytest.raises(RateLimitError): await target.send_prompt_async(message=message) assert mock_request.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS") @@ -470,7 +477,6 @@ async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIChat @pytest.mark.asyncio async def test_send_prompt_async_bad_request_error(target: OpenAIChatTarget): - response = MagicMock() response.status_code = 400 @@ -486,7 +492,6 @@ async def test_send_prompt_async_bad_request_error(target: OpenAIChatTarget): @pytest.mark.asyncio async def test_send_prompt_async_content_filter_200(target: OpenAIChatTarget): - response_body = json.dumps( { "choices": [ @@ -522,7 +527,6 @@ async def test_send_prompt_async_content_filter_200(target: OpenAIChatTarget): def test_validate_request_unsupported_data_types(target: OpenAIChatTarget): - image_piece = get_image_message_piece() image_piece.converted_value_data_type = "new_unknown_type" # type: ignore message = Message( @@ -567,7 +571,6 @@ def test_inheritance_from_prompt_chat_target_base(): def test_is_response_format_json_supported(target: OpenAIChatTarget): - message_piece = MessagePiece( role="user", original_value="original prompt text", @@ -578,10 +581,28 @@ def test_is_response_format_json_supported(target: OpenAIChatTarget): ) result = target.is_response_format_json(message_piece) - + assert isinstance(result, bool) assert result is True +def test_is_response_format_json_schema_supported(target: OpenAIChatTarget): + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + message_piece = MessagePiece( + role="user", + original_value="original prompt text", + converted_value="Hello, how are you?", + conversation_id="conversation_1", + sequence=0, + prompt_metadata={ + "response_format": "json", + "json_schema": json.dumps(schema), + }, + ) + + result = target.is_response_format_json(message_piece) + assert result + + def test_is_response_format_json_no_metadata(target: OpenAIChatTarget): message_piece = MessagePiece( role="user", @@ -649,7 +670,6 @@ async def test_send_prompt_async_calls_refresh_auth_headers(target: OpenAIChatTa patch.object(target, "_validate_request"), patch.object(target, "_construct_request_body", new_callable=AsyncMock) as mock_construct, ): - mock_construct.return_value = {} with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async") as mock_make_request: @@ -684,7 +704,6 @@ async def test_send_prompt_async_content_filter_400(target: OpenAIChatTarget): patch.object(target, "_validate_request"), patch.object(target, 
"_construct_request_body", new_callable=AsyncMock) as mock_construct, ): - mock_construct.return_value = {} error_json = {"error": {"code": "content_filter"}} @@ -751,7 +770,6 @@ def test_set_auth_headers_with_entra_auth(patch_central_database): patch("pyrit.prompt_target.openai.openai_target.get_default_scope") as mock_scope, patch("pyrit.prompt_target.openai.openai_target.AzureAuth") as mock_auth_class, ): - mock_scope.return_value = "https://cognitiveservices.azure.com/.default" mock_auth_instance = MagicMock() mock_auth_instance.get_token.return_value = "test_token_123" diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index 6d4275ee0..61d94974d 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -23,7 +23,7 @@ RateLimitException, ) from pyrit.memory.memory_interface import MemoryInterface -from pyrit.models import Message, MessagePiece +from pyrit.models import JsonResponseConfig, Message, MessagePiece from pyrit.prompt_target import OpenAIResponseTarget, PromptChatTarget @@ -87,7 +87,6 @@ def test_init_with_no_additional_request_headers_var_raises(): @pytest.mark.asyncio() async def test_build_input_for_multi_modal(target: OpenAIResponseTarget): - image_request = get_image_message_piece() conversation_id = image_request.conversation_id entries = [ @@ -165,23 +164,37 @@ async def test_construct_request_body_includes_extra_body_params( request = Message(message_pieces=[dummy_text_message_piece]) - body = await target._construct_request_body(conversation=[request], is_json_response=False) + jrc = JsonResponseConfig.from_metadata(metadata=None) + body = await target._construct_request_body(conversation=[request], json_config=jrc) assert body["key"] == "value" @pytest.mark.asyncio -@pytest.mark.parametrize("is_json", [True, False]) -async def test_construct_request_body_includes_json( - is_json, target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece -): +async def test_construct_request_body_json_object(target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece): + json_response_config = JsonResponseConfig(enabled=True) + request = Message(message_pieces=[dummy_text_message_piece]) + + body = await target._construct_request_body(conversation=[request], json_config=json_response_config) + assert body["text"] == {"format": {"type": "json_object"}} + +@pytest.mark.asyncio +async def test_construct_request_body_json_schema(target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece): + schema_object = {"type": "object", "properties": {"name": {"type": "string"}}} + json_response_config = JsonResponseConfig.from_metadata( + metadata={"response_format": "json", "json_schema": schema_object} + ) request = Message(message_pieces=[dummy_text_message_piece]) - body = await target._construct_request_body(conversation=[request], is_json_response=is_json) - if is_json: - assert body["response_format"] == {"type": "json_object"} - else: - assert "response_format" not in body + body = await target._construct_request_body(conversation=[request], json_config=json_response_config) + assert body["text"] == { + "format": { + "type": "json_schema", + "schema": schema_object, + "name": "CustomSchema", + "strict": True, + } + } @pytest.mark.asyncio @@ -190,13 +203,15 @@ async def test_construct_request_body_removes_empty_values( ): request = Message(message_pieces=[dummy_text_message_piece]) - body = await 
target._construct_request_body(conversation=[request], is_json_response=False) + json_response_config = JsonResponseConfig(enabled=False) + body = await target._construct_request_body(conversation=[request], json_config=json_response_config) assert "max_completion_tokens" not in body assert "max_tokens" not in body assert "temperature" not in body assert "top_p" not in body assert "frequency_penalty" not in body assert "presence_penalty" not in body + assert "text" not in body @pytest.mark.asyncio @@ -205,7 +220,8 @@ async def test_construct_request_body_serializes_text_message( ): request = Message(message_pieces=[dummy_text_message_piece]) - body = await target._construct_request_body(conversation=[request], is_json_response=False) + jrc = JsonResponseConfig.from_metadata(metadata=None) + body = await target._construct_request_body(conversation=[request], json_config=jrc) assert body["input"][0]["content"][0]["text"] == "dummy text" @@ -213,13 +229,13 @@ async def test_construct_request_body_serializes_text_message( async def test_construct_request_body_serializes_complex_message( target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece ): - image_piece = get_image_message_piece() dummy_text_message_piece.conversation_id = image_piece.conversation_id request = Message(message_pieces=[dummy_text_message_piece, image_piece]) + jrc = JsonResponseConfig.from_metadata(metadata=None) - body = await target._construct_request_body(conversation=[request], is_json_response=False) + body = await target._construct_request_body(conversation=[request], json_config=jrc) messages = body["input"][0]["content"] assert len(messages) == 2 assert messages[0]["type"] == "input_text" @@ -303,7 +319,6 @@ async def test_send_prompt_async_rate_limit_exception_adds_to_memory( side_effect = httpx.HTTPStatusError("Rate Limit Reached", response=response, request=MagicMock()) with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect): - message = Message(message_pieces=[MessagePiece(role="user", conversation_id="123", original_value="Hello")]) with pytest.raises(RateLimitException) as rle: @@ -426,7 +441,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di with patch( "pyrit.common.net_utility.make_request_and_raise_if_error_async", new_callable=AsyncMock ) as mock_create: - openai_mock_return = MagicMock() openai_mock_return.text = json.dumps(openai_response_json) mock_create.return_value = openai_mock_return @@ -440,7 +454,6 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di @pytest.mark.asyncio async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIResponseTarget): - message = Message(message_pieces=[MessagePiece(role="user", conversation_id="12345", original_value="Hello")]) response = MagicMock() @@ -451,7 +464,6 @@ async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIResp with patch( "pyrit.common.net_utility.make_request_and_raise_if_error_async", side_effect=side_effect ) as mock_request: - with pytest.raises(RateLimitError): await target.send_prompt_async(message=message) assert mock_request.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS") @@ -459,7 +471,6 @@ async def test_send_prompt_async_rate_limit_exception_retries(target: OpenAIResp @pytest.mark.asyncio async def test_send_prompt_async_bad_request_error(target: OpenAIResponseTarget): - response = MagicMock() response.status_code = 400 @@ -475,7 +486,6 @@ async def 
test_send_prompt_async_bad_request_error(target: OpenAIResponseTarget) @pytest.mark.asyncio async def test_send_prompt_async_content_filter(target: OpenAIResponseTarget): - response_body = json.dumps( { "error": { @@ -511,7 +521,6 @@ async def test_send_prompt_async_content_filter(target: OpenAIResponseTarget): def test_validate_request_unsupported_data_types(target: OpenAIResponseTarget): - image_piece = get_image_message_piece() image_piece.converted_value_data_type = "new_unknown_type" # type: ignore message = Message( @@ -544,7 +553,6 @@ def test_inheritance_from_prompt_chat_target(target: OpenAIResponseTarget): def test_is_response_format_json_supported(target: OpenAIResponseTarget): - message_piece = MessagePiece( role="user", original_value="original prompt text", @@ -556,9 +564,28 @@ def test_is_response_format_json_supported(target: OpenAIResponseTarget): result = target.is_response_format_json(message_piece) + assert isinstance(result, bool) assert result is True +def test_is_response_format_json_schema_supported(target: OpenAIResponseTarget): + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + message_piece = MessagePiece( + role="user", + original_value="original prompt text", + converted_value="Hello, how are you?", + conversation_id="conversation_1", + sequence=0, + prompt_metadata={ + "response_format": "json", + "json_schema": json.dumps(schema), + }, + ) + + result = target.is_response_format_json(message_piece) + assert result + + def test_is_response_format_json_no_metadata(target: OpenAIResponseTarget): message_piece = MessagePiece( role="user", @@ -619,7 +646,6 @@ async def test_send_prompt_async_calls_refresh_auth_headers(target: OpenAIRespon patch.object(target, "_validate_request"), patch.object(target, "_construct_request_body", new_callable=AsyncMock) as mock_construct, ): - mock_construct.return_value = {} with patch("pyrit.common.net_utility.make_request_and_raise_if_error_async") as mock_make_request: @@ -761,7 +787,8 @@ async def test_construct_request_body_filters_none( target: OpenAIResponseTarget, dummy_text_message_piece: MessagePiece ): req = Message(message_pieces=[dummy_text_message_piece]) - body = await target._construct_request_body([req], is_json_response=False) + jrc = JsonResponseConfig.from_metadata(metadata=None) + body = await target._construct_request_body(conversation=[req], json_config=jrc) assert "max_output_tokens" not in body or body["max_output_tokens"] is None assert "temperature" not in body or body["temperature"] is None assert "top_p" not in body or body["top_p"] is None
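End to end, callers opt in purely through prompt metadata. A hedged usage sketch (the endpoint, model, and key values are placeholders, not part of this change):

# Sketch: requesting schema-constrained JSON from a chat target.
import json

from pyrit.models import MessagePiece
from pyrit.prompt_target import OpenAIChatTarget

target = OpenAIChatTarget(
    endpoint="https://example.openai.azure.com",  # placeholder
    model_name="my-deployment",  # placeholder
    api_key="...",  # placeholder
)

piece = MessagePiece(
    role="user",
    original_value="Describe a cat as JSON.",
    original_value_data_type="text",
    conversation_id="example-conversation",
    prompt_metadata={
        "response_format": "json",
        "json_schema": json.dumps({"type": "object", "properties": {"name": {"type": "string"}}}),
        "schema_name": "Cat",  # optional; defaults to "CustomSchema"
        "strict": True,  # optional; defaults to True
    },
)

# Inside an async context:
# response = await target.send_prompt_async(message=piece.to_message())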