Skip to content

Commit ccaf87f

Browse files
committed
Add run_stream_sync
1 parent 78fb707 commit ccaf87f

File tree

3 files changed

+356
-1
lines changed

3 files changed

+356
-1
lines changed

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,86 @@ async def on_complete() -> None:
569569
if not yielded:
570570
raise exceptions.AgentRunError('Agent run finished without producing a final result') # pragma: no cover
571571

572+
@contextmanager
def run_stream_sync(
    self,
    user_prompt: str | Sequence[_messages.UserContent] | None = None,
    *,
    output_type: OutputSpec[RunOutputDataT] | None = None,
    message_history: Sequence[_messages.ModelMessage] | None = None,
    deferred_tool_results: DeferredToolResults | None = None,
    model: models.Model | models.KnownModelName | str | None = None,
    deps: AgentDepsT = None,
    model_settings: ModelSettings | None = None,
    usage_limits: _usage.UsageLimits | None = None,
    usage: _usage.RunUsage | None = None,
    infer_name: bool = True,
    toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
    builtin_tools: Sequence[AbstractBuiltinTool] | None = None,
    event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
) -> Iterator[result.CollectedRunResult[AgentDepsT, Any]]:
    """Run the agent with a user prompt in collected streaming mode.

    This method builds an internal agent graph (using system prompts, tools and output schemas) and then
    runs the graph until the model produces output matching the `output_type`, for example text or structured data.
    At this point, a streaming run result object is collected and -- once this output has completed streaming -- you can iterate over the complete output, message history, and usage.

    As this method will consider the first output matching the `output_type` to be the final output,
    it will stop running the agent graph and will not execute any tool calls made by the model after this "final" output.
    If you want to always run the agent graph to completion and stream events and output at the same time,
    use [`agent.run()`][pydantic_ai.agent.AbstractAgent.run] with an `event_stream_handler` or [`agent.iter()`][pydantic_ai.agent.AbstractAgent.iter] instead.

    Example:
    ```python
    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    def main():
        with agent.run_stream_sync('What is the capital of the UK?') as response:
            print(response.get_output())
            #> The capital of the UK is London.
    ```

    Args:
        user_prompt: User input to start/continue the conversation.
        output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no
            output validators since output validators would expect an argument that matches the agent's output type.
        message_history: History of the conversation so far.
        deferred_tool_results: Optional results for deferred tool calls in the message history.
        model: Optional model to use for this run, required if `model` was not set when creating the agent.
        deps: Optional dependencies to use for this run.
        model_settings: Optional settings to use for this model's request.
        usage_limits: Optional limits on model request count or token usage.
        usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
        infer_name: Whether to try to infer the agent name from the call frame if it's not set.
        toolsets: Optional additional toolsets for this run.
        builtin_tools: Optional additional builtin tools for this run.
        event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
            It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
            Note that it does _not_ receive any events after the final result is found.

    Returns:
        The result of the run.
    """
    async_cm = self.run_stream(
        user_prompt,
        output_type=output_type,
        message_history=message_history,
        deferred_tool_results=deferred_tool_results,
        model=model,
        deps=deps,
        model_settings=model_settings,
        usage_limits=usage_limits,
        usage=usage,
        infer_name=infer_name,
        toolsets=toolsets,
        builtin_tools=builtin_tools,
        event_stream_handler=event_stream_handler,
    )
    loop = get_event_loop()
    async_result = loop.run_until_complete(async_cm.__aenter__())
    try:
        yield result.CollectedRunResult.from_streamed_result(async_result)  # type: ignore[reportReturnType]
    except BaseException as exc:
        # Mirror `async with` semantics: give the async context manager a chance to
        # handle (and possibly suppress) the exception raised inside the `with` body.
        if not loop.run_until_complete(async_cm.__aexit__(type(exc), exc, exc.__traceback__)):
            raise
    else:
        # Bug fix: the async context manager was previously entered but never exited,
        # so `run_stream`'s cleanup/finalization (everything after its `yield`) never ran.
        loop.run_until_complete(async_cm.__aexit__(None, None, None))
651+
572652
@overload
573653
def run_stream_events(
574654
self,

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations as _annotations
22

3-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
3+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
44
from copy import deepcopy
55
from dataclasses import dataclass, field
66
from datetime import datetime
@@ -9,6 +9,8 @@
99
from pydantic import ValidationError
1010
from typing_extensions import TypeVar, deprecated
1111

12+
from pydantic_graph._utils import get_event_loop
13+
1214
from . import _utils, exceptions, messages as _messages, models
1315
from ._output import (
1416
OutputDataT_inv,
@@ -543,6 +545,85 @@ async def _marked_completed(self, message: _messages.ModelResponse | None = None
543545
await self._on_complete()
544546

545547

548+
@dataclass(init=False)
class CollectedRunResult(StreamedRunResult[AgentDepsT, OutputDataT]):
    """Synchronous facade over `StreamedRunResult` that eagerly drains the async stream.

    Each `stream_*` method runs the corresponding async stream to completion on the
    event loop, then yields the collected items as a plain synchronous iterator.
    """

    @classmethod
    def from_streamed_result(
        cls, streamed_run_result: StreamedRunResult[AgentDepsT, OutputDataT]
    ) -> CollectedRunResult[AgentDepsT, OutputDataT]:
        """Create a CollectedRunResult from an existing StreamedRunResult."""
        # `init=False` dataclass: bypass `__init__` and mirror the source's state directly.
        collected = cls.__new__(cls)
        for attr_name in (
            '_all_messages',
            '_new_message_index',
            '_stream_response',
            '_on_complete',
            '_run_result',
            'is_complete',
        ):
            setattr(collected, attr_name, getattr(streamed_run_result, attr_name))
        return collected

    def _collect_async_iterator(self, async_iter: AsyncIterator[T]) -> list[T]:
        # Drain the async iterator by driving the event loop to completion.
        async def _drain() -> list[T]:
            items: list[T] = []
            async for item in async_iter:
                items.append(item)
            return items

        return get_event_loop().run_until_complete(_drain())

    def stream_output(self, *, debounce_by: float | None = 0.1) -> Iterator[OutputDataT]:  # type: ignore[reportIncompatibleMethodOverride]
        """Collect and stream the output as an iterable.

        The pydantic validator for structured data will be called in
        [partial mode](https://docs.pydantic.dev/dev/concepts/experimental/#partial-validation)
        on each iteration.

        Args:
            debounce_by: by how much (if at all) to debounce/group the output chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured outputs to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An iterable of the response data.
        """
        yield from self._collect_async_iterator(super().stream_output(debounce_by=debounce_by))

    def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> Iterator[str]:  # type: ignore[reportIncompatibleMethodOverride]
        """Collect and stream the text result as an iterable.

        !!! note
            Result validators will NOT be called on the text result if `delta=True`.

        Args:
            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
                up to the current point.
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.
        """
        yield from self._collect_async_iterator(super().stream_text(delta=delta, debounce_by=debounce_by))

    def stream_responses(self, *, debounce_by: float | None = 0.1) -> Iterator[tuple[_messages.ModelResponse, bool]]:  # type: ignore[reportIncompatibleMethodOverride]
        """Collect and stream the response as an iterable of Structured LLM Messages.

        Args:
            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
                Debouncing is particularly important for long structured responses to reduce the overhead of
                performing validation as each token is received.

        Returns:
            An iterable of the structured response message and whether that is the last message.
        """
        yield from self._collect_async_iterator(super().stream_responses(debounce_by=debounce_by))

    def get_output(self) -> OutputDataT:  # type: ignore[reportIncompatibleMethodOverride]
        """Stream the whole response, validate and return it."""
        return get_event_loop().run_until_complete(super().get_output())
625+
626+
546627
@dataclass(repr=False)
547628
class FinalResult(Generic[OutputDataT]):
548629
"""Marker class storing the final output of an agent run and associated metadata."""

tests/test_streaming.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,86 @@ async def ret_a(x: str) -> str:
134134
)
135135

136136

137+
def test_streamed_text_sync_response():
    """`run_stream_sync` exposes messages and usage both before and after consuming the output."""
    m = TestModel()

    test_agent = Agent(m)
    assert test_agent.name is None

    @test_agent.tool_plain
    async def ret_a(x: str) -> str:
        return f'{x}-apple'

    with test_agent.run_stream_sync('Hello') as result:
        # assert test_agent.name == 'test_agent'
        assert not result.is_complete
        # Before the output is consumed, history stops at the tool-return request.
        assert result.all_messages() == snapshot(
            [
                ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
                    usage=RequestUsage(input_tokens=51),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
                        )
                    ]
                ),
            ]
        )
        assert result.usage() == snapshot(
            RunUsage(
                requests=2,
                input_tokens=103,
                output_tokens=5,
                tool_calls=1,
            )
        )
        response = result.get_output()
        assert response == snapshot('{"ret_a":"a-apple"}')
        assert result.is_complete
        assert result.timestamp() == IsNow(tz=timezone.utc)
        # After `get_output()`, the final model response is appended to the history.
        assert result.all_messages() == snapshot(
            [
                ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
                ModelResponse(
                    parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'}, tool_call_id=IsStr())],
                    usage=RequestUsage(input_tokens=51),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelRequest(
                    parts=[
                        ToolReturnPart(
                            tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
                        )
                    ]
                ),
                ModelResponse(
                    parts=[TextPart(content='{"ret_a":"a-apple"}')],
                    usage=RequestUsage(input_tokens=52, output_tokens=11),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
            ]
        )
        # Usage now reflects the fully streamed final response (output_tokens 5 -> 11).
        assert result.usage() == snapshot(
            RunUsage(
                requests=2,
                input_tokens=103,
                output_tokens=11,
                tool_calls=1,
            )
        )
215+
216+
137217
async def test_streamed_structured_response():
138218
m = TestModel()
139219

@@ -302,6 +382,120 @@ def upcase(text: str) -> str:
302382
)
303383

304384

385+
def test_streamed_text_stream_sync():
    """The synchronous `stream_*` helpers yield the same chunks as their async counterparts."""
    m = TestModel(custom_output_text='The cat sat on the mat.')

    agent = Agent(m)

    with agent.run_stream_sync('Hello') as result:
        # typehint to test (via static typing) that the stream type is correctly inferred
        chunks: list[str] = [c for c in result.stream_text()]
        # one chunk with `stream_text()` due to group_by_temporal
        assert chunks == snapshot(['The cat sat on the mat.'])
        assert result.is_complete

    with agent.run_stream_sync('Hello') as result:
        # typehint to test (via static typing) that the stream type is correctly inferred
        chunks: list[str] = [c for c in result.stream_output()]
        # two chunks with `stream()` due to not-final vs. final
        assert chunks == snapshot(['The cat sat on the mat.', 'The cat sat on the mat.'])
        assert result.is_complete

    with agent.run_stream_sync('Hello') as result:
        # With debouncing disabled, each cumulative chunk is yielded separately.
        assert [c for c in result.stream_text(debounce_by=None)] == snapshot(
            [
                'The ',
                'The cat ',
                'The cat sat ',
                'The cat sat on ',
                'The cat sat on the ',
                'The cat sat on the mat.',
            ]
        )

    with agent.run_stream_sync('Hello') as result:
        # with stream_text, there is no need to do partial validation, so we only get the final message once:
        assert [c for c in result.stream_text(delta=False, debounce_by=None)] == snapshot(
            ['The ', 'The cat ', 'The cat sat ', 'The cat sat on ', 'The cat sat on the ', 'The cat sat on the mat.']
        )

    with agent.run_stream_sync('Hello') as result:
        # delta=True yields only the newly received text per chunk.
        assert [c for c in result.stream_text(delta=True, debounce_by=None)] == snapshot(
            ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.']
        )

    def upcase(text: str) -> str:
        return text.upper()

    with agent.run_stream_sync('Hello', output_type=TextOutput(upcase)) as result:
        # The output function is applied to each partial chunk; the final chunk appears twice
        # (not-final then final), matching `stream_output` semantics.
        assert [c for c in result.stream_output(debounce_by=None)] == snapshot(
            [
                'THE ',
                'THE CAT ',
                'THE CAT SAT ',
                'THE CAT SAT ON ',
                'THE CAT SAT ON THE ',
                'THE CAT SAT ON THE MAT.',
                'THE CAT SAT ON THE MAT.',
            ]
        )

    with agent.run_stream_sync('Hello') as result:
        # stream_responses yields (ModelResponse, is_last) pairs; usage grows with each chunk.
        assert [c for c, _is_last in result.stream_responses(debounce_by=None)] == snapshot(
            [
                ModelResponse(
                    parts=[TextPart(content='The ')],
                    usage=RequestUsage(input_tokens=51, output_tokens=1),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelResponse(
                    parts=[TextPart(content='The cat ')],
                    usage=RequestUsage(input_tokens=51, output_tokens=2),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelResponse(
                    parts=[TextPart(content='The cat sat ')],
                    usage=RequestUsage(input_tokens=51, output_tokens=3),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelResponse(
                    parts=[TextPart(content='The cat sat on ')],
                    usage=RequestUsage(input_tokens=51, output_tokens=4),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelResponse(
                    parts=[TextPart(content='The cat sat on the ')],
                    usage=RequestUsage(input_tokens=51, output_tokens=5),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelResponse(
                    parts=[TextPart(content='The cat sat on the mat.')],
                    usage=RequestUsage(input_tokens=51, output_tokens=7),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
                ModelResponse(
                    parts=[TextPart(content='The cat sat on the mat.')],
                    usage=RequestUsage(input_tokens=51, output_tokens=7),
                    model_name='test',
                    timestamp=IsNow(tz=timezone.utc),
                    provider_name='test',
                ),
            ]
        )
498+
305499
async def test_plain_response():
306500
call_index = 0
307501

0 commit comments

Comments
 (0)