diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index 2a3752c370..99f8344ae5 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -11,6 +11,7 @@
 from typing_extensions import assert_never
 
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
+from .._output import OutputObjectDefinition
 from .._run_context import RunContext
 from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..exceptions import ModelAPIError, UserError
@@ -62,6 +63,7 @@
         Mistral,
         OptionalNullable as MistralOptionalNullable,
         ReferenceChunk as MistralReferenceChunk,
+        ResponseFormat as MistralResponseFormat,
         TextChunk as MistralTextChunk,
         ThinkChunk as MistralThinkChunk,
         ToolChoiceEnum as MistralToolChoiceEnum,
@@ -70,6 +72,7 @@
         ChatCompletionResponse as MistralChatCompletionResponse,
         CompletionEvent as MistralCompletionEvent,
         FinishReason as MistralFinishReason,
+        JSONSchema as MistralJSONSchema,
         Messages as MistralMessages,
         SDKError,
         Tool as MistralTool,
@@ -215,6 +218,32 @@ async def request_stream(
         async with response:
             yield await self._process_streamed_response(response, model_request_parameters)
 
+    def _get_response_format(self, model_request_parameters: ModelRequestParameters) -> MistralResponseFormat | None:
+        """Get the response format for Mistral API based on output mode.
+
+        Returns None if no special format is needed.
+        """
+        if model_request_parameters.output_mode == 'native':
+            # Use native JSON schema mode
+            output_object = model_request_parameters.output_object
+            assert output_object is not None
+            json_schema = self._map_json_schema(output_object)
+            return MistralResponseFormat(type='json_schema', json_schema=json_schema)
+        elif model_request_parameters.output_mode == 'prompted' and not model_request_parameters.function_tools:
+            # Use JSON object mode (without schema)
+            return MistralResponseFormat(type='json_object')
+        else:
+            return None
+
+    def _map_json_schema(self, o: OutputObjectDefinition) -> MistralJSONSchema:
+        """Map OutputObjectDefinition to Mistral JSONSchema format."""
+        return MistralJSONSchema(
+            name=o.name or 'output',
+            schema_definition=o.json_schema,
+            description=o.description or UNSET,
+            strict=o.strict if o.strict is not None else None,
+        )
+
     async def _completions_create(
         self,
         messages: list[ModelMessage],
@@ -227,13 +256,25 @@ async def _completions_create(
         if model_request_parameters.builtin_tools:
             raise UserError('Mistral does not support built-in tools')
 
+        # Determine the response format based on output mode
+        response_format = self._get_response_format(model_request_parameters)
+
+        # When using native JSON schema mode, don't use tool-based output
+        if model_request_parameters.output_mode == 'native':
+            tools = self._map_function_tools_only(model_request_parameters) or UNSET
+            tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
+        else:
+            tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
+            tool_choice = self._get_tool_choice(model_request_parameters)
+
         try:
             response = await self.client.chat.complete_async(
                 model=str(self._model_name),
                 messages=self._map_messages(messages, model_request_parameters),
                 n=1,
-                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
-                tool_choice=self._get_tool_choice(model_request_parameters),
+                tools=tools,
+                tool_choice=tool_choice,
+                response_format=response_format,
                 stream=False,
                 max_tokens=model_settings.get('max_tokens', UNSET),
                 temperature=model_settings.get('temperature', UNSET),
@@ -258,57 +299,41 @@ async def _stream_completions_create(
         model_request_parameters: ModelRequestParameters,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
-        response: MistralEventStreamAsync[MistralCompletionEvent] | None
-        mistral_messages = self._map_messages(messages, model_request_parameters)
-
         # TODO(Marcelo): We need to replace the current MistralAI client to use the beta client.
         # See https://docs.mistral.ai/agents/connectors/websearch/ to support web search.
         if model_request_parameters.builtin_tools:
             raise UserError('Mistral does not support built-in tools')
 
-        if model_request_parameters.function_tools:
-            # Function Calling
-            response = await self.client.chat.stream_async(
-                model=str(self._model_name),
-                messages=mistral_messages,
-                n=1,
-                tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET,
-                tool_choice=self._get_tool_choice(model_request_parameters),
-                temperature=model_settings.get('temperature', UNSET),
-                top_p=model_settings.get('top_p', 1),
-                max_tokens=model_settings.get('max_tokens', UNSET),
-                timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
-                presence_penalty=model_settings.get('presence_penalty'),
-                frequency_penalty=model_settings.get('frequency_penalty'),
-                stop=model_settings.get('stop_sequences', None),
-                http_headers={'User-Agent': get_user_agent()},
-            )
-
-        elif model_request_parameters.output_tools:
-            # TODO: Port to native "manual JSON" mode
-            # Json Mode
-            parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools]
-            user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
-            mistral_messages.append(user_output_format_message)
+        mistral_messages = self._map_messages(messages, model_request_parameters)
 
-            response = await self.client.chat.stream_async(
-                model=str(self._model_name),
-                messages=mistral_messages,
-                response_format={
-                    'type': 'json_object'
-                },  # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9
-                stream=True,
-                http_headers={'User-Agent': get_user_agent()},
-            )
+        # Determine the response format based on output mode
+        response_format = self._get_response_format(model_request_parameters)
 
+        # When using native JSON schema mode, don't use tool-based output
+        if model_request_parameters.output_mode == 'native':
+            tools = self._map_function_tools_only(model_request_parameters) or UNSET
+            tool_choice = self._get_tool_choice_for_functions_only(model_request_parameters)
         else:
-            # Stream Mode
-            response = await self.client.chat.stream_async(
-                model=str(self._model_name),
-                messages=mistral_messages,
-                stream=True,
-                http_headers={'User-Agent': get_user_agent()},
-            )
+            tools = self._map_function_and_output_tools_definition(model_request_parameters) or UNSET
+            tool_choice = self._get_tool_choice(model_request_parameters)
+
+        response = await self.client.chat.stream_async(
+            model=str(self._model_name),
+            messages=mistral_messages,
+            n=1,
+            tools=tools,
+            tool_choice=tool_choice,
+            response_format=response_format,
+            temperature=model_settings.get('temperature', UNSET),
+            top_p=model_settings.get('top_p', 1),
+            max_tokens=model_settings.get('max_tokens', UNSET),
+            timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+            presence_penalty=model_settings.get('presence_penalty'),
+            frequency_penalty=model_settings.get('frequency_penalty'),
+            stop=model_settings.get('stop_sequences', None),
+            stream=True,
+            http_headers={'User-Agent': get_user_agent()},
+        )
 
         assert response, 'A unexpected empty response from Mistral.'
         return response
@@ -344,6 +369,30 @@ def _map_function_and_output_tools_definition(
         ]
         return tools if tools else None
 
+    def _map_function_tools_only(self, model_request_parameters: ModelRequestParameters) -> list[MistralTool] | None:
+        """Map only function tools (not output tools) to MistralTool format.
+
+        This is used when output is handled via native JSON schema mode instead of tools.
+        """
+        tools = [
+            MistralTool(
+                function=MistralFunction(
+                    name=r.name, parameters=r.parameters_json_schema, description=r.description or ''
+                )
+            )
+            for r in model_request_parameters.function_tools
+        ]
+        return tools if tools else None
+
+    def _get_tool_choice_for_functions_only(
+        self, model_request_parameters: ModelRequestParameters
+    ) -> MistralToolChoiceEnum | None:
+        """Get tool choice when only function tools are used (not output tools)."""
+        if not model_request_parameters.function_tools:
+            return None
+        # When using native output mode, we don't force tool use since output is handled separately
+        return 'auto'
+
     def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         assert response.choices, 'Unexpected empty response choice.'
diff --git a/pydantic_ai_slim/pydantic_ai/profiles/mistral.py b/pydantic_ai_slim/pydantic_ai/profiles/mistral.py
index 438cc95833..b2a99faffa 100644
--- a/pydantic_ai_slim/pydantic_ai/profiles/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/profiles/mistral.py
@@ -1,8 +1,13 @@
 from __future__ import annotations as _annotations
 
 from . import ModelProfile
+from .openai import OpenAIJsonSchemaTransformer
 
 
 def mistral_model_profile(model_name: str) -> ModelProfile | None:
     """Get the model profile for a Mistral model."""
-    return None
+    return ModelProfile(
+        json_schema_transformer=OpenAIJsonSchemaTransformer,
+        supports_json_schema_output=True,
+        supports_json_object_output=True,
+    )
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 32a56c6395..b92de4bb02 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -29,6 +29,7 @@
 )
 from pydantic_ai.agent import Agent
 from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry
+from pydantic_ai.output import NativeOutput, PromptedOutput
 from pydantic_ai.usage import RequestUsage
 
 from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import
@@ -2345,3 +2346,221 @@ async def test_mistral_model_thinking_part_iter(allow_model_requests: None, mist
             ),
         ]
     )
+
+
+#####################
+## Native Output
+#####################
+
+
+async def test_mistral_native_output(allow_model_requests: None):
+    """Test Mistral native JSON schema output with a simple model."""
+
+    class CityLocation(BaseModel):
+        """A city and its country."""
+
+        city: str
+        country: str
+
+    completion = completion_message(
+        MistralAssistantMessage(
+            content='{"city": "Mexico City", "country": "Mexico"}',
+            role='assistant',
+        ),
+        usage=MistralUsageInfo(prompt_tokens=10, completion_tokens=15, total_tokens=25),
+    )
+    mock_client = MockMistralAI.create_mock(completion)
+    model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
+    agent = Agent(model, output_type=NativeOutput(CityLocation))
+
+    result = await agent.run('What is the largest city in Mexico?')
+
+    assert result.output == CityLocation(city='Mexico City', country='Mexico')
+    assert result.usage().input_tokens == 10
+    assert result.usage().output_tokens == 15
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='What is the largest city in Mexico?',
+                        timestamp=IsNow(tz=timezone.utc),
+                    )
+                ],
+                run_id=IsStr(),
+            ),
+            ModelResponse(
+                parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')],
+                usage=RequestUsage(input_tokens=10, output_tokens=15),
+                model_name='mistral-large-123',
+                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                provider_name='mistral',
+                provider_details={'finish_reason': 'stop'},
+                provider_response_id='123',
+                finish_reason='stop',
+                run_id=IsStr(),
+            ),
+        ]
+    )
+
+
+async def test_mistral_native_output_multiple_schemas(allow_model_requests: None):
+    """Test Mistral native JSON schema output with multiple possible schemas."""
+
+    class CityLocation(BaseModel):
+        city: str
+        country: str
+
+    class CountryLanguage(BaseModel):
+        country: str
+        language: str
+
+    completion = completion_message(
+        MistralAssistantMessage(
+            content='{"result": {"kind": "CountryLanguage", "data": {"country": "Mexico", "language": "Spanish"}}}',
+            role='assistant',
+        ),
+        usage=MistralUsageInfo(prompt_tokens=12, completion_tokens=18, total_tokens=30),
+    )
+    mock_client = MockMistralAI.create_mock(completion)
+    model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
+    agent = Agent(model, output_type=NativeOutput([CityLocation, CountryLanguage]))
+
+    result = await agent.run('What is the primary language spoken in Mexico?')
+
+    assert result.output == CountryLanguage(country='Mexico', language='Spanish')
+    assert result.usage().input_tokens == 12
+    assert result.usage().output_tokens == 18
+
+
+async def test_mistral_native_output_streaming(allow_model_requests: None):
+    """Test Mistral native JSON schema output with streaming."""
+
+    class CityLocation(BaseModel):
+        """A city and its country."""
+
+        city: str
+        country: str
+
+    stream_events = [
+        text_chunk('{"city": '),
+        text_chunk('"Paris", '),
+        text_chunk('"country": '),
+        text_chunk('"France"}', finish_reason='stop'),
+    ]
+    mock_client = MockMistralAI.create_stream_mock(stream_events)
+    model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
+    agent = Agent(model, output_type=NativeOutput(CityLocation))
+
+    async with agent.run_stream('What is the capital of France?') as result:
+        assert not result.is_complete
+        outputs = [output async for output in result.stream_output(debounce_by=None)]
+
+    assert result.is_complete
+    final_output = await result.get_output()
+    assert final_output == CityLocation(city='Paris', country='France')
+    # Verify that partial outputs were streamed
+    assert len(outputs) > 0
+    assert outputs[-1] == CityLocation(city='Paris', country='France')
+
+
+async def test_mistral_native_output_with_nested_model(allow_model_requests: None):
+    """Test Mistral native JSON schema output with nested models."""
+
+    class Address(BaseModel):
+        street: str
+        city: str
+
+    class Person(BaseModel):
+        name: str
+        address: Address
+
+    completion = completion_message(
+        MistralAssistantMessage(
+            content='{"name": "John Doe", "address": {"street": "123 Main St", "city": "Paris"}}',
+            role='assistant',
+        ),
+        usage=MistralUsageInfo(prompt_tokens=15, completion_tokens=25, total_tokens=40),
+    )
+    mock_client = MockMistralAI.create_mock(completion)
+    model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
+    agent = Agent(model, output_type=NativeOutput(Person))
+
+    result = await agent.run('Create a person')
+
+    assert result.output == Person(name='John Doe', address=Address(street='123 Main St', city='Paris'))
+    assert result.usage().input_tokens == 15
+    assert result.usage().output_tokens == 25
+
+
+async def test_prompted_output_json_object_mode(allow_model_requests: None):
+    """Test prompted mode with json_object response format (no function tools)."""
+
+    class CityInfo(BaseModel):
+        city: str
+        population: int
+
+    completion = completion_message(
+        MistralAssistantMessage(
+            content='{"city": "Tokyo", "population": 14000000}',
+            role='assistant',
+        ),
+        usage=MistralUsageInfo(prompt_tokens=10, completion_tokens=15, total_tokens=25),
+    )
+    mock_client = MockMistralAI.create_mock(completion)
+    model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
+    # PromptedOutput uses 'prompted' output mode without function tools
+    agent = Agent(model, output_type=PromptedOutput(CityInfo))
+
+    result = await agent.run('Tell me about Tokyo')
+
+    assert result.output == CityInfo(city='Tokyo', population=14000000)
+    assert result.usage().input_tokens == 10
+    assert result.usage().output_tokens == 15
+
+
+async def test_native_output_with_function_tools(allow_model_requests: None):
+    """Test native output mode with function tools present (tool_choice should be 'auto')."""
+
+    class CityLocation(BaseModel):
+        city: str
+        country: str
+
+    completion = [
+        # First response: call the tool
+        completion_message(
+            MistralAssistantMessage(
+                content=None,
+                role='assistant',
+                tool_calls=[
+                    MistralToolCall(
+                        id='1',
+                        function=MistralFunctionCall(arguments='{"city_name": "Paris"}', name='get_coordinates'),
+                        type='function',
+                    )
+                ],
+            ),
+            usage=MistralUsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15),
+        ),
+        # Second response: return native output
+        completion_message(
+            MistralAssistantMessage(
+                content='{"city": "Paris", "country": "France"}',
+                role='assistant',
+            ),
+            usage=MistralUsageInfo(prompt_tokens=15, completion_tokens=10, total_tokens=25),
+        ),
+    ]
+    mock_client = MockMistralAI.create_mock(completion)
+    model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
+    agent = Agent(model, output_type=NativeOutput(CityLocation))
+
+    @agent.tool_plain
+    async def get_coordinates(city_name: str) -> str:
+        return '{"lat": 48.8566, "lng": 2.3522}'
+
+    result = await agent.run('Tell me about Paris')
+
+    assert result.output == CityLocation(city='Paris', country='France')
+    assert result.usage().input_tokens == 25
+    assert result.usage().output_tokens == 15
diff --git a/tests/providers/test_groq.py b/tests/providers/test_groq.py
index 307c8101f2..178c9f4877 100644
--- a/tests/providers/test_groq.py
+++ b/tests/providers/test_groq.py
@@ -119,7 +119,7 @@ def test_groq_provider_model_profile(mocker: MockerFixture):
 
     mistral_profile = provider.model_profile('mistral-saba-24b')
     mistral_model_profile_mock.assert_called_with('mistral-saba-24b')
-    assert mistral_profile is None
+    assert mistral_profile is not None
 
     qwen_profile = provider.model_profile('qwen-qwq-32b')
     qwen_model_profile_mock.assert_called_with('qwen-qwq-32b')
diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py
index 50c1c1e9e6..68df4ad1c3 100644
--- a/tests/providers/test_huggingface.py
+++ b/tests/providers/test_huggingface.py
@@ -178,7 +178,7 @@ def test_huggingface_provider_model_profile(mocker: MockerFixture):
 
     mistral_profile = provider.model_profile('mistralai/Devstral-Small-2505')
    mistral_model_profile_mock.assert_called_with('devstral-small-2505')
-    assert mistral_profile is None
+    assert mistral_profile is not None
 
     google_profile = provider.model_profile('google/gemma-3-27b-it')
     google_model_profile_mock.assert_called_with('gemma-3-27b-it')