7
7
import json
8
8
import logging
9
9
import mimetypes
10
- from typing import Any , Iterable , Optional , TypedDict , cast
10
+ from typing import Any , Callable , Iterable , Optional , Type , TypedDict , TypeVar , cast
11
11
12
12
import anthropic
13
+ from pydantic import BaseModel
13
14
from typing_extensions import Required , Unpack , override
14
15
16
+ from ..event_loop .streaming import process_stream
17
+ from ..handlers .callback_handler import PrintingCallbackHandler
18
+ from ..tools import convert_pydantic_to_tool_spec
15
19
from ..types .content import ContentBlock , Messages
16
20
from ..types .exceptions import ContextWindowOverflowException , ModelThrottledException
17
21
from ..types .models import Model
20
24
21
25
logger = logging .getLogger (__name__ )
22
26
27
+ T = TypeVar ("T" , bound = BaseModel )
28
+
23
29
24
30
class AnthropicModel (Model ):
25
31
"""Anthropic model provider implementation."""
@@ -356,10 +362,10 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
356
362
with self .client .messages .stream (** request ) as stream :
357
363
for event in stream :
358
364
if event .type in AnthropicModel .EVENT_TYPES :
359
- yield event .dict ()
365
+ yield event .model_dump ()
360
366
361
367
usage = event .message .usage # type: ignore
362
- yield {"type" : "metadata" , "usage" : usage .dict ()}
368
+ yield {"type" : "metadata" , "usage" : usage .model_dump ()}
363
369
364
370
except anthropic .RateLimitError as error :
365
371
raise ModelThrottledException (str (error )) from error
@@ -369,3 +375,42 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
369
375
raise ContextWindowOverflowException (str (error )) from error
370
376
371
377
raise error
378
+
379
+ @override
380
+ def structured_output (
381
+ self , output_model : Type [T ], prompt : Messages , callback_handler : Optional [Callable ] = None
382
+ ) -> T :
383
+ """Get structured output from the model.
384
+
385
+ Args:
386
+ output_model(Type[BaseModel]): The output model to use for the agent.
387
+ prompt(Messages): The prompt messages to use for the agent.
388
+ callback_handler(Optional[Callable]): Optional callback handler for processing events. Defaults to None.
389
+ """
390
+ tool_spec = convert_pydantic_to_tool_spec (output_model )
391
+
392
+ response = self .converse (messages = prompt , tool_specs = [tool_spec ])
393
+ # process the stream and get the tool use input
394
+ results = process_stream (
395
+ response , callback_handler = callback_handler or PrintingCallbackHandler (), messages = prompt
396
+ )
397
+
398
+ stop_reason , messages , _ , _ , _ = results
399
+
400
+ if stop_reason != "tool_use" :
401
+ raise ValueError ("No valid tool use or tool use input was found in the Anthropic response." )
402
+
403
+ content = messages ["content" ]
404
+ output_response : dict [str , Any ] | None = None
405
+ for block in content :
406
+ # if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip.
407
+ # if the tool use name never matches, raise an error.
408
+ if block .get ("toolUse" ) and block ["toolUse" ]["name" ] == tool_spec ["name" ]:
409
+ output_response = block ["toolUse" ]["input" ]
410
+ else :
411
+ continue
412
+
413
+ if output_response is None :
414
+ raise ValueError ("No valid tool use or tool use input was found in the Anthropic response." )
415
+
416
+ return output_model (** output_response )
0 commit comments