diff --git a/docs/agents.md b/docs/agents.md
index ab3f658b7a..b5640c2a1d 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -65,7 +65,7 @@ There are five ways to run an agent:
 1. [`agent.run()`][pydantic_ai.agent.AbstractAgent.run] — an async function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response.
 2. [`agent.run_sync()`][pydantic_ai.agent.AbstractAgent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`).
-3. [`agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream] — an async context manager which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream text and structured output as an async iterable.
+3. [`agent.run_stream()`][pydantic_ai.agent.AbstractAgent.run_stream] — an async context manager which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream text and structured output as an async iterable. [`agent.run_stream_sync()`][pydantic_ai.agent.AbstractAgent.run_stream_sync] is a synchronous variant that returns a plain (non-async) context manager yielding the same [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult].
 4. [`agent.run_stream_events()`][pydantic_ai.agent.AbstractAgent.run_stream_events] — a function which returns an async iterable of [`AgentStreamEvent`s][pydantic_ai.messages.AgentStreamEvent] and a [`AgentRunResultEvent`][pydantic_ai.run.AgentRunResultEvent] containing the final run result.
 5. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async iterable over the nodes of the agent's underlying [`Graph`][pydantic_graph.graph.Graph].
diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
index 43cd4fa749..f1e10b49d1 100644
--- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py
+++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py
@@ -4,7 +4,7 @@
 import inspect
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterator, Mapping, Sequence
-from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
+from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager, contextmanager
 from types import FrameType
 from typing import TYPE_CHECKING, Any, Generic, TypeAlias, cast, overload
@@ -581,6 +581,115 @@ async def on_complete() -> None:
         if not yielded:
             raise exceptions.AgentRunError('Agent run finished without producing a final result')  # pragma: no cover
 
+    @overload
+    def run_stream_sync(
+        self,
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
+        *,
+        output_type: None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
+        model: models.Model | models.KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.RunUsage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+        builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
+        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
+    ) -> AbstractContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...
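+
+    # NOTE: the overloads mirror `run_stream`: the one above covers the agent's
+    # default `output_type`, the one below a per-run `output_type` override;
+    # the actual `@contextmanager` implementation follows them.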
+
+    @overload
+    def run_stream_sync(
+        self,
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
+        *,
+        output_type: OutputSpec[RunOutputDataT],
+        message_history: Sequence[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
+        model: models.Model | models.KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.RunUsage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+        builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
+        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
+    ) -> AbstractContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
+
+    @contextmanager
+    def run_stream_sync(
+        self,
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
+        *,
+        output_type: OutputSpec[RunOutputDataT] | None = None,
+        message_history: Sequence[_messages.ModelMessage] | None = None,
+        deferred_tool_results: DeferredToolResults | None = None,
+        model: models.Model | models.KnownModelName | str | None = None,
+        deps: AgentDepsT = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.RunUsage | None = None,
+        infer_name: bool = True,
+        toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
+        builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
+        event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
+    ) -> Iterator[result.StreamedRunResult[AgentDepsT, Any]]:
+        """Run the agent with a user prompt in sync streaming mode.
+
+        This is a convenience method that wraps [`self.run_stream`][pydantic_ai.agent.AbstractAgent.run_stream] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        This method builds an internal agent graph (using system prompts, tools and output schemas) and then
+        runs the graph until the model produces output matching the `output_type`, for example text or structured data.
+        At this point, a streaming run result object is yielded from which you can stream the output as it comes in,
+        and -- once this output has completed streaming -- get the complete output, message history, and usage.
+
+        As this method will consider the first output matching the `output_type` to be the final output,
+        it will stop running the agent graph and will not execute any tool calls made by the model after this "final" output.
+        If you want to always run the agent graph to completion and stream events and output at the same time,
+        use [`agent.run()`][pydantic_ai.agent.AbstractAgent.run] with an `event_stream_handler` or [`agent.iter()`][pydantic_ai.agent.AbstractAgent.iter] instead.
+
+        Args:
+            user_prompt: User input to start/continue the conversation.
+            output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
+                output validators since output validators would expect an argument that matches the agent's output type.
+            message_history: History of the conversation so far.
+            deferred_tool_results: Optional results for deferred tool calls in the message history.
+            model: Optional model to use for this run, required if `model` was not set when creating the agent.
+            deps: Optional dependencies to use for this run.
+            model_settings: Optional settings to use for this model's request.
+            usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
+            infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            toolsets: Optional additional toolsets for this run.
+            builtin_tools: Optional additional builtin tools for this run.
+            event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
+                It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
+                Note that it does _not_ receive any events after the final result is found.
+
+        Returns:
+            The result of the run.
+        """
+        async_cm = self.run_stream(
+            user_prompt,
+            output_type=output_type,
+            message_history=message_history,
+            deferred_tool_results=deferred_tool_results,
+            model=model,
+            deps=deps,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            usage=usage,
+            infer_name=infer_name,
+            toolsets=toolsets,
+            builtin_tools=builtin_tools,
+            event_stream_handler=event_stream_handler,
+        )
+        loop = get_event_loop()
+        try:
+            yield loop.run_until_complete(async_cm.__aenter__())
+        except BaseException as e:
+            # Propagate the error into the async context manager so it can clean up,
+            # then re-raise unless the context manager suppressed it.
+            if not loop.run_until_complete(async_cm.__aexit__(type(e), e, e.__traceback__)):
+                raise
+        else:
+            # Exit the async context manager so the model stream and graph run are closed.
+            loop.run_until_complete(async_cm.__aexit__(None, None, None))
+
     @overload
     def run_stream_events(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index f5b542953e..d3e47d3788 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -1,6 +1,6 @@
 from __future__ import annotations as _annotations
 
-from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
+from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -9,6 +9,8 @@
 from pydantic import ValidationError
 from typing_extensions import TypeVar, deprecated
 
+from pydantic_graph._utils import get_event_loop
+
 from . import _utils, exceptions, messages as _messages, models
 from ._output import (
     OutputDataT_inv,
@@ -408,6 +410,27 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat
         else:
             raise ValueError('No stream response or run result provided')  # pragma: no cover
 
+    def stream_output_sync(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:
+        """Stream the output as a sync iterable.
+
+        This is a convenience method that wraps [`self.stream_output`][pydantic_ai.result.StreamedRunResult.stream_output] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        The pydantic validator for structured data will be called in
+        [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
+        on each iteration.
+
+        Args:
+            debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured outputs to reduce the overhead of
+                performing validation as each token is received.
+
+        Returns:
+            An iterable of the response data.
+        """
+        async_stream = self.stream_output(debounce_by=debounce_by)
+        yield from _blocking_async_iterator(async_stream)
+
     async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
         """Stream the text result as an async iterable.
@@ -436,6 +459,25 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
         else:
             raise ValueError('No stream response or run result provided')  # pragma: no cover
 
+    def stream_text_sync(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:
+        """Stream the text result as a sync iterable.
+
+        This is a convenience method that wraps [`self.stream_text`][pydantic_ai.result.StreamedRunResult.stream_text] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        !!! note
+            Result validators will NOT be called on the text result if `delta=True`.
+
+        Args:
+            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
+                up to the current point.
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+        """
+        async_stream = self.stream_text(delta=delta, debounce_by=debounce_by)
+        yield from _blocking_async_iterator(async_stream)
+
     @deprecated('`StreamedRunResult.stream_structured` is deprecated, use `stream_responses` instead.')
     async def stream_structured(
         self, *, debounce_by: float | None = 0.1
@@ -471,6 +513,25 @@ async def stream_responses(
         else:
             raise ValueError('No stream response or run result provided')  # pragma: no cover
 
+    def stream_responses_sync(
+        self, *, debounce_by: float | None = 0.1
+    ) -> Iterator[tuple[_messages.ModelResponse, bool]]:
+        """Stream the response as an iterable of Structured LLM Messages.
+
+        This is a convenience method that wraps [`self.stream_responses`][pydantic_ai.result.StreamedRunResult.stream_responses] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        Args:
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+
+        Returns:
+            An iterable of the structured response message and whether that is the last message.
+        """
+        async_stream = self.stream_responses(debounce_by=debounce_by)
+        yield from _blocking_async_iterator(async_stream)
+
     async def get_output(self) -> OutputDataT:
         """Stream the whole response, validate and return it."""
         if self._run_result is not None:
@@ -484,6 +545,14 @@ async def stream_responses(
         else:
             raise ValueError('No stream response or run result provided')  # pragma: no cover
 
+    def get_output_sync(self) -> OutputDataT:
+        """Stream the whole response, validate and return it.
+
+        This is a convenience method that wraps [`self.get_output`][pydantic_ai.result.StreamedRunResult.get_output] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
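+
+        Example (a minimal sketch; `agent` stands for any configured `Agent`):
+
+            with agent.run_stream_sync('Hello') as result:
+                output = result.get_output_sync()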
+ """ + return get_event_loop().run_until_complete(self.get_output()) + @property def response(self) -> _messages.ModelResponse: """Return the current state of the response.""" @@ -559,6 +628,17 @@ class FinalResult(Generic[OutputDataT]): __repr__ = _utils.dataclasses_no_defaults_repr +def _blocking_async_iterator(async_iter: AsyncIterator[T]) -> Iterator[T]: + loop = get_event_loop() + + while True: + try: + item = loop.run_until_complete(async_iter.__anext__()) + yield item + except StopAsyncIteration: + break + + def _get_usage_checking_stream_response( stream_response: models.StreamedResponse, limits: UsageLimits | None, diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 537b76e03d..08fefed42e 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -134,6 +134,92 @@ async def ret_a(x: str) -> str: ) +@pytest.fixture +def close_cached_httpx_client(): + """Override the global fixture to avoid async context issues in sync tests.""" + yield + + +def test_streamed_text_sync_response(close_cached_httpx_client): # type: ignore[reportUnknownParameterType] + m = TestModel() + + test_agent = Agent(m) + assert test_agent.name is None + + @test_agent.tool_plain + async def ret_a(x: str) -> str: + return f'{x}-apple' + + with test_agent.run_stream_sync('Hello') as result: + # assert test_agent.name == 'test_agent' + assert not result.is_complete + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=51), + model_name='test', + timestamp=IsNow(tz=timezone.utc), + provider_name='test', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + ) + ] + ), + ] + ) + assert result.usage() == snapshot( + RunUsage( + requests=2, + input_tokens=103, + output_tokens=5, + tool_calls=1, + ) + ) + response = result.get_output_sync() + assert response == snapshot('{"ret_a":"a-apple"}') + assert result.is_complete + assert result.timestamp() == IsNow(tz=timezone.utc) + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())], + usage=RequestUsage(input_tokens=51), + model_name='test', + timestamp=IsNow(tz=timezone.utc), + provider_name='test', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], + usage=RequestUsage(input_tokens=52, output_tokens=11), + model_name='test', + timestamp=IsNow(tz=timezone.utc), + provider_name='test', + ), + ] + ) + assert result.usage() == snapshot( + RunUsage( + requests=2, + input_tokens=103, + output_tokens=11, + tool_calls=1, + ) + ) + + async def test_streamed_structured_response(): m = TestModel() @@ -302,6 +388,120 @@ def upcase(text: str) -> str: ) +def test_streamed_text_stream_sync(): + m = TestModel(custom_output_text='The cat sat on the mat.') + + agent = Agent(m) + + with agent.run_stream_sync('Hello') as result: + # typehint to test (via static typing) that the stream type is correctly inferred + chunks: list[str] = [c for c in result.stream_text_sync()] + # one chunk with `stream_text()` due to 
+        assert chunks == snapshot(['The cat sat on the mat.'])
+        assert result.is_complete
+
+    with agent.run_stream_sync('Hello') as result:
+        # typehint to test (via static typing) that the stream type is correctly inferred
+        chunks: list[str] = [c for c in result.stream_output_sync()]
+        # two chunks with `stream_output_sync()` due to not-final vs. final
+        assert chunks == snapshot(['The cat sat on the mat.', 'The cat sat on the mat.'])
+        assert result.is_complete
+
+    with agent.run_stream_sync('Hello') as result:
+        assert [c for c in result.stream_text_sync(debounce_by=None)] == snapshot(
+            [
+                'The ',
+                'The cat ',
+                'The cat sat ',
+                'The cat sat on ',
+                'The cat sat on the ',
+                'The cat sat on the mat.',
+            ]
+        )
+
+    with agent.run_stream_sync('Hello') as result:
+        # with stream_text, there is no need to do partial validation, so we only get the final message once:
+        assert [c for c in result.stream_text_sync(delta=False, debounce_by=None)] == snapshot(
+            ['The ', 'The cat ', 'The cat sat ', 'The cat sat on ', 'The cat sat on the ', 'The cat sat on the mat.']
+        )
+
+    with agent.run_stream_sync('Hello') as result:
+        assert [c for c in result.stream_text_sync(delta=True, debounce_by=None)] == snapshot(
+            ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.']
+        )
+
+    def upcase(text: str) -> str:
+        return text.upper()
+
+    with agent.run_stream_sync('Hello', output_type=TextOutput(upcase)) as result:
+        assert [c for c in result.stream_output_sync(debounce_by=None)] == snapshot(
+            [
+                'THE ',
+                'THE CAT ',
+                'THE CAT SAT ',
+                'THE CAT SAT ON ',
+                'THE CAT SAT ON THE ',
+                'THE CAT SAT ON THE MAT.',
+                'THE CAT SAT ON THE MAT.',
+            ]
+        )
+
+    with agent.run_stream_sync('Hello') as result:
+        assert [c for c, _is_last in result.stream_responses_sync(debounce_by=None)] == snapshot(
+            [
+                ModelResponse(
+                    parts=[TextPart(content='The ')],
+                    usage=RequestUsage(input_tokens=51, output_tokens=1),
+                    model_name='test',
+                    timestamp=IsNow(tz=timezone.utc),
+                    provider_name='test',
+                ),
+                ModelResponse(
+                    parts=[TextPart(content='The cat ')],
+                    usage=RequestUsage(input_tokens=51, output_tokens=2),
+                    model_name='test',
+                    timestamp=IsNow(tz=timezone.utc),
+                    provider_name='test',
+                ),
+                ModelResponse(
+                    parts=[TextPart(content='The cat sat ')],
+                    usage=RequestUsage(input_tokens=51, output_tokens=3),
+                    model_name='test',
+                    timestamp=IsNow(tz=timezone.utc),
+                    provider_name='test',
+                ),
+                ModelResponse(
+                    parts=[TextPart(content='The cat sat on ')],
+                    usage=RequestUsage(input_tokens=51, output_tokens=4),
+                    model_name='test',
+                    timestamp=IsNow(tz=timezone.utc),
+                    provider_name='test',
+                ),
+                ModelResponse(
+                    parts=[TextPart(content='The cat sat on the ')],
+                    usage=RequestUsage(input_tokens=51, output_tokens=5),
+                    model_name='test',
+                    timestamp=IsNow(tz=timezone.utc),
+                    provider_name='test',
+                ),
+                ModelResponse(
+                    parts=[TextPart(content='The cat sat on the mat.')],
+                    usage=RequestUsage(input_tokens=51, output_tokens=7),
+                    model_name='test',
+                    timestamp=IsNow(tz=timezone.utc),
+                    provider_name='test',
+                ),
+                ModelResponse(
+                    parts=[TextPart(content='The cat sat on the mat.')],
+                    usage=RequestUsage(input_tokens=51, output_tokens=7),
+                    model_name='test',
+                    timestamp=IsNow(tz=timezone.utc),
+                    provider_name='test',
+                ),
+            ]
+        )
+
+
 async def test_plain_response():
     call_index = 0
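
For reference, here is a minimal end-to-end sketch of the sync streaming API added in this diff. It reuses `TestModel` (as the test suite does) purely for illustration; any configured model works the same way:

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent(TestModel(custom_output_text='The cat sat on the mat.'))

# Stream text deltas from plain sync code: no asyncio.run() or `async with` needed.
with agent.run_stream_sync('Hello') as result:
    for chunk in result.stream_text_sync(delta=True, debounce_by=None):
        print(chunk, end='')

# Or block until the complete, validated output is ready.
with agent.run_stream_sync('Hello') as result:
    print(result.get_output_sync())
```

Like `run_sync()`, these helpers drive the underlying coroutines with `loop.run_until_complete(...)`, so they can't be called while an event loop is already running (e.g. inside `async` code or a Jupyter cell).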