16 changes: 10 additions & 6 deletions examples/pydantic_ai_examples/weather_agent.py
@@ -13,7 +13,6 @@
 
 import asyncio
 from dataclasses import dataclass
-from typing import Any
 
 import logfire
 from httpx import AsyncClient
@@ -46,6 +45,11 @@ class LatLng(BaseModel):
     lng: float
 
 
+class WeatherResponse(BaseModel):
+    temperature: str
+    description: str
+
+
 @weather_agent.tool
 async def get_lat_lng(ctx: RunContext[Deps], location_description: str) -> LatLng:
     """Get the latitude and longitude of a location.
@@ -64,7 +68,7 @@ async def get_lat_lng(ctx: RunContext[Deps], location_description: str) -> LatLng:
 
 
 @weather_agent.tool
-async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
+async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> WeatherResponse:
     """Get the weather at a location.
 
     Args:
@@ -85,10 +89,10 @@ async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
     )
     temp_response.raise_for_status()
     descr_response.raise_for_status()
-    return {
-        'temperature': f'{temp_response.text} °C',
-        'description': descr_response.text,
-    }
+    return WeatherResponse(
+        temperature=f'{temp_response.text} °C',
+        description=descr_response.text,
+    )
 
 
 async def main():
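For context on why the tool now returns a Pydantic model instead of `dict[str, Any]`: a fully declared model yields a closed JSON schema, whereas an open-ended dict can surface a boolean `additionalProperties`, which is the shape Gradio's schema parser trips over (see the patch in the next file). A minimal sketch, not part of the PR, assuming Pydantic v2's `TypeAdapter` and `model_json_schema`:

```python
# Illustrative only: compares the schemas the two return types generate.
from typing import Any

from pydantic import BaseModel, TypeAdapter


class WeatherResponse(BaseModel):
    temperature: str
    description: str


# Every property is declared, so schema consumers never see a bare boolean
# `additionalProperties`.
print(WeatherResponse.model_json_schema())
# {'properties': {'temperature': {...}, 'description': {...}}, 'required': [...], ...}

# The old dict[str, Any] return is open-ended; depending on the schema
# generator it can show up as `additionalProperties: true`.
print(TypeAdapter(dict[str, Any]).json_schema())
```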
33 changes: 29 additions & 4 deletions examples/pydantic_ai_examples/weather_agent_gradio.py
@@ -10,18 +10,39 @@
 
 try:
     import gradio as gr
+    from gradio_client import utils as gradio_utils
 except ImportError as e:
     raise ImportError(
         'Please install gradio with `pip install gradio`. You must use python>=3.10.'
     ) from e
 
+# Monkey patch to fix Gradio's JSON schema parser for boolean additionalProperties
+_original_json_schema_to_python_type = gradio_utils._json_schema_to_python_type
Review comment (Collaborator):
Should this be fixed in Gradio instead?
If using a BaseModel fixes it, we don't need this here right?

+
+
+def _patched_json_schema_to_python_type(schema, defs):
+    """Handle boolean additionalProperties in JSON schemas."""
+    if isinstance(schema, bool):
+        return 'Any'
+    if isinstance(schema, dict) and isinstance(schema.get('additionalProperties'), bool):
+        schema = schema.copy()
+        if schema['additionalProperties']:
+            schema['additionalProperties'] = {}
+        else:
+            schema.pop('additionalProperties', None)
+    return _original_json_schema_to_python_type(schema, defs)
+
+
+gradio_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
+
 TOOL_TO_DISPLAY_NAME = {'get_lat_lng': 'Geocoding API', 'get_weather': 'Weather API'}
 
 client = AsyncClient()
 deps = Deps(client=client)
 
 
-async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
+async def stream_from_agent(prompt, chatbot, past_messages):
+    """Stream agent responses with tool calls to Gradio chatbot."""
     chatbot.append({'role': 'user', 'content': prompt})
     yield gr.Textbox(interactive=False, value=''), chatbot, gr.skip()
     async with weather_agent.run_stream(
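To make the reviewer's question above concrete, here is a standalone sketch (not part of the PR) of what the monkey patch normalises before delegating. `fake_parser` is a hypothetical stand-in for `gradio_client.utils._json_schema_to_python_type` so the sketch runs without Gradio installed.

```python
# Illustrative only: a bare boolean schema maps to 'Any', and a boolean
# additionalProperties is rewritten into a form the downstream parser can
# recurse into.
def fake_parser(schema, defs):
    return f'parsed({schema!r})'


def patched(schema, defs):
    if isinstance(schema, bool):
        return 'Any'  # a bare true/false schema carries no structure
    if isinstance(schema, dict) and isinstance(schema.get('additionalProperties'), bool):
        schema = schema.copy()
        if schema['additionalProperties']:
            schema['additionalProperties'] = {}  # true -> unconstrained value schema
        else:
            schema.pop('additionalProperties', None)  # false -> drop the key
    return fake_parser(schema, defs)


print(patched({'type': 'object', 'additionalProperties': True}, {}))
# parsed({'type': 'object', 'additionalProperties': {}})
print(patched(True, {}))
# Any
```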
@@ -64,21 +85,24 @@ async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
     yield gr.Textbox(interactive=True), gr.skip(), past_messages
 
 
-async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData):
+async def handle_retry(chatbot, past_messages, retry_data: gr.RetryData):
+    """Handle retry events from the chatbot."""
     new_history = chatbot[: retry_data.index]
     previous_prompt = chatbot[retry_data.index]['content']
     past_messages = past_messages[: retry_data.index]
     async for update in stream_from_agent(previous_prompt, new_history, past_messages):
         yield update
 
 
-def undo(chatbot, past_messages: list, undo_data: gr.UndoData):
+def undo(chatbot, past_messages, undo_data: gr.UndoData):
+    """Handle undo events from the chatbot."""
     new_history = chatbot[: undo_data.index]
     past_messages = past_messages[: undo_data.index]
     return chatbot[undo_data.index]['content'], new_history, past_messages
 
 
-def select_data(message: gr.SelectData) -> str:
+def select_data(message: gr.SelectData):
+    """Handle example selection from the chatbot."""
     return message.value['text']
 
 
@@ -116,6 +140,7 @@ def select_data(message: gr.SelectData) -> str:
         stream_from_agent,
         inputs=[prompt, chatbot, past_messages],
         outputs=[prompt, chatbot, past_messages],
+        api_name=False,
     )
     chatbot.example_select(select_data, None, [prompt])
     chatbot.retry(