diff --git a/examples/pydantic_ai_examples/weather_agent.py b/examples/pydantic_ai_examples/weather_agent.py index f02cf854f6..1fbc28300f 100644 --- a/examples/pydantic_ai_examples/weather_agent.py +++ b/examples/pydantic_ai_examples/weather_agent.py @@ -13,7 +13,6 @@ import asyncio from dataclasses import dataclass -from typing import Any import logfire from httpx import AsyncClient @@ -46,6 +45,11 @@ class LatLng(BaseModel): lng: float +class WeatherResponse(BaseModel): + temperature: str + description: str + + @weather_agent.tool async def get_lat_lng(ctx: RunContext[Deps], location_description: str) -> LatLng: """Get the latitude and longitude of a location. @@ -64,7 +68,7 @@ async def get_lat_lng(ctx: RunContext[Deps], location_description: str) -> LatLn @weather_agent.tool -async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]: +async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> WeatherResponse: """Get the weather at a location. Args: @@ -85,10 +89,10 @@ async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str ) temp_response.raise_for_status() descr_response.raise_for_status() - return { - 'temperature': f'{temp_response.text} °C', - 'description': descr_response.text, - } + return WeatherResponse( + temperature=f'{temp_response.text} °C', + description=descr_response.text, + ) async def main(): diff --git a/examples/pydantic_ai_examples/weather_agent_gradio.py b/examples/pydantic_ai_examples/weather_agent_gradio.py index 0a1163ffbd..a78257c495 100755 --- a/examples/pydantic_ai_examples/weather_agent_gradio.py +++ b/examples/pydantic_ai_examples/weather_agent_gradio.py @@ -10,18 +10,39 @@ try: import gradio as gr + from gradio_client import utils as gradio_utils except ImportError as e: raise ImportError( 'Please install gradio with `pip install gradio`. You must use python>=3.10.' ) from e +# Monkey patch to fix Gradio's JSON schema parser for boolean additionalProperties +_original_json_schema_to_python_type = gradio_utils._json_schema_to_python_type + + +def _patched_json_schema_to_python_type(schema, defs): + """Handle boolean additionalProperties in JSON schemas.""" + if isinstance(schema, bool): + return 'Any' + if isinstance(schema, dict) and isinstance(schema.get('additionalProperties'), bool): + schema = schema.copy() + if schema['additionalProperties']: + schema['additionalProperties'] = {} + else: + schema.pop('additionalProperties', None) + return _original_json_schema_to_python_type(schema, defs) + + +gradio_utils._json_schema_to_python_type = _patched_json_schema_to_python_type + TOOL_TO_DISPLAY_NAME = {'get_lat_lng': 'Geocoding API', 'get_weather': 'Weather API'} client = AsyncClient() deps = Deps(client=client) -async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list): +async def stream_from_agent(prompt, chatbot, past_messages): + """Stream agent responses with tool calls to Gradio chatbot.""" chatbot.append({'role': 'user', 'content': prompt}) yield gr.Textbox(interactive=False, value=''), chatbot, gr.skip() async with weather_agent.run_stream( @@ -64,7 +85,8 @@ async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: lis yield gr.Textbox(interactive=True), gr.skip(), past_messages -async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData): +async def handle_retry(chatbot, past_messages, retry_data: gr.RetryData): + """Handle retry events from the chatbot.""" new_history = chatbot[: retry_data.index] previous_prompt = chatbot[retry_data.index]['content'] past_messages = past_messages[: retry_data.index] @@ -72,13 +94,15 @@ async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData): yield update -def undo(chatbot, past_messages: list, undo_data: gr.UndoData): +def undo(chatbot, past_messages, undo_data: gr.UndoData): + """Handle undo events from the chatbot.""" new_history = chatbot[: undo_data.index] past_messages = past_messages[: undo_data.index] return chatbot[undo_data.index]['content'], new_history, past_messages -def select_data(message: gr.SelectData) -> str: +def select_data(message: gr.SelectData): + """Handle example selection from the chatbot.""" return message.value['text'] @@ -116,6 +140,7 @@ def select_data(message: gr.SelectData) -> str: stream_from_agent, inputs=[prompt, chatbot, past_messages], outputs=[prompt, chatbot, past_messages], + api_name=False, ) chatbot.example_select(select_data, None, [prompt]) chatbot.retry(