From 0a8b1d472e9ddd4e61f1f79c14665ecd0d8b0c3c Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Thu, 2 Oct 2025 14:49:34 -0400 Subject: [PATCH 01/13] feat: replace kwargs with invocation_state in agent APIs --- src/strands/agent/agent.py | 52 ++++++++++++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 4579ebacf..a70e4b58f 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -272,6 +272,7 @@ def __init__( Raises: ValueError: If agent id contains path separators. """ + self.invocation_state: dict[str, Any] | None = None self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] @@ -374,7 +375,9 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + def __call__( + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -389,6 +392,7 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history + invocation_state: [New] Additional parameters to pass through the event loop. **kwargs: Additional parameters to pass through the event loop. Returns: @@ -399,15 +403,24 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ + if kwargs: + logger.warning("`Kwargs` parameter is deprecated, use the `invocation_state` parameter instead.") + self.invocation_state = kwargs + if invocation_state is not None: + self.invocation_state["invocation_state"] = invocation_state + else: + self.invocation_state = invocation_state def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, **kwargs)) + return asyncio.run(self.invoke_async(prompt, invocation_state=self.invocation_state)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult: + async def invoke_async( + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any + ) -> AgentResult: """Process a natural language prompt through the agent's event loop. This method implements the conversational interface with multiple input patterns: @@ -422,6 +435,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history + invocation_state: [New] Additional parameters to pass through the event loop. **kwargs: Additional parameters to pass through the event loop. 
Returns: @@ -432,7 +446,15 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - events = self.stream_async(prompt, **kwargs) + if kwargs: + logger.warning("`Kwargs` parameter is deprecated, use the `invocation_state` parameter instead.") + self.invocation_state = kwargs + if invocation_state is not None: + self.invocation_state["invocation_state"] = invocation_state + else: + self.invocation_state = invocation_state + + events = self.stream_async(prompt, invocation_state=self.invocation_state) async for event in events: _ = event @@ -528,9 +550,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def stream_async( - self, - prompt: AgentInput = None, - **kwargs: Any, + self, prompt: AgentInput = None, *, invocation_state: Optional[dict[str, Any]] = None, **kwargs: Any ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -546,6 +566,7 @@ async def stream_async( - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history + invocation_state: [New] Additional parameters to pass through the event loop. **kwargs: Additional parameters to pass to the event loop. Yields: @@ -567,7 +588,18 @@ async def stream_async( yield event["data"] ``` """ - callback_handler = kwargs.get("callback_handler", self.callback_handler) + if kwargs: + logger.warning("`Kwargs` parameter is deprecated, use the `invocation_state` parameter instead.") + self.invocation_state = kwargs + if invocation_state is not None: + self.invocation_state["invocation_state"] = invocation_state + else: + self.invocation_state = invocation_state or {} + + # Get callback handler from merged state or use default + callback_handler = self.invocation_state.get("invocation_state", {}).get( + "callback_handler", self.invocation_state.get("callback_handler", self.callback_handler) + ) # Process input and get message to add (if any) messages = self._convert_prompt_to_messages(prompt) @@ -576,10 +608,10 @@ async def stream_async( with trace_api.use_span(self.trace_span): try: - events = self._run_loop(messages, invocation_state=kwargs) + events = self._run_loop(messages, invocation_state=self.invocation_state) async for event in events: - event.prepare(invocation_state=kwargs) + event.prepare(invocation_state=self.invocation_state) if event.is_callback_event: as_dict = event.as_dict() From f6ed2f2000fd021067e7479ece7b910eb34f7d26 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Thu, 2 Oct 2025 21:39:48 -0400 Subject: [PATCH 02/13] fix: handle **kwargs in stream_async. --- src/strands/agent/agent.py | 51 ++++++++++++++------------------------ 1 file changed, 19 insertions(+), 32 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index a70e4b58f..f39796b1e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -272,7 +272,6 @@ def __init__( Raises: ValueError: If agent id contains path separators. 
""" - self.invocation_state: dict[str, Any] | None = None self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] @@ -392,8 +391,8 @@ def __call__( - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - invocation_state: [New] Additional parameters to pass through the event loop. - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result object containing: @@ -403,16 +402,9 @@ def __call__( - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - if kwargs: - logger.warning("`Kwargs` parameter is deprecated, use the `invocation_state` parameter instead.") - self.invocation_state = kwargs - if invocation_state is not None: - self.invocation_state["invocation_state"] = invocation_state - else: - self.invocation_state = invocation_state def execute() -> AgentResult: - return asyncio.run(self.invoke_async(prompt, invocation_state=self.invocation_state)) + return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) @@ -435,8 +427,8 @@ async def invoke_async( - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - invocation_state: [New] Additional parameters to pass through the event loop. - **kwargs: Additional parameters to pass through the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass through the event loop.[Deprecating] Returns: Result: object containing: @@ -446,15 +438,7 @@ async def invoke_async( - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - if kwargs: - logger.warning("`Kwargs` parameter is deprecated, use the `invocation_state` parameter instead.") - self.invocation_state = kwargs - if invocation_state is not None: - self.invocation_state["invocation_state"] = invocation_state - else: - self.invocation_state = invocation_state - - events = self.stream_async(prompt, invocation_state=self.invocation_state) + events = self.stream_async(prompt, invocation_state=invocation_state, **kwargs) async for event in events: _ = event @@ -566,8 +550,8 @@ async def stream_async( - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history - invocation_state: [New] Additional parameters to pass through the event loop. - **kwargs: Additional parameters to pass to the event loop. + invocation_state: Additional parameters to pass through the event loop. + **kwargs: Additional parameters to pass to the event loop.[Deprecating] Yields: An async iterator that yields events. 
Each event is a dictionary containing @@ -588,17 +572,20 @@ async def stream_async( yield event["data"] ``` """ + merged_state = {} if kwargs: - logger.warning("`Kwargs` parameter is deprecated, use the `invocation_state` parameter instead.") - self.invocation_state = kwargs + logger.warning("`**kwargs` parameter is deprecated, use `invocation_state` instead.") + merged_state.update(kwargs) if invocation_state is not None: - self.invocation_state["invocation_state"] = invocation_state + merged_state["invocation_state"] = invocation_state else: - self.invocation_state = invocation_state or {} + if invocation_state is not None: + merged_state["invocation_state"] = invocation_state # Get callback handler from merged state or use default - callback_handler = self.invocation_state.get("invocation_state", {}).get( - "callback_handler", self.invocation_state.get("callback_handler", self.callback_handler) + invocation_state_dict = merged_state.get("invocation_state") or {} + callback_handler = invocation_state_dict.get( + "callback_handler", merged_state.get("callback_handler", self.callback_handler) ) # Process input and get message to add (if any) @@ -608,10 +595,10 @@ async def stream_async( with trace_api.use_span(self.trace_span): try: - events = self._run_loop(messages, invocation_state=self.invocation_state) + events = self._run_loop(messages, invocation_state=merged_state) async for event in events: - event.prepare(invocation_state=self.invocation_state) + event.prepare(invocation_state=merged_state) if event.is_callback_event: as_dict = event.as_dict() From 967e8056c0f5eeb752b511edc11d55130c9e4cc1 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Fri, 3 Oct 2025 11:06:20 -0400 Subject: [PATCH 03/13] feat: add a unit test for the change --- tests/strands/agent/test_agent.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..a5a627632 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1877,3 +1877,17 @@ def test_tool(action: str) -> str: assert '"action": "test_value"' in tool_call_text assert '"agent"' not in tool_call_text assert '"extra_param"' not in tool_call_text + + +def test_agent__call__handles_none_invocation_state(mock_model, agent): + """Test that agent handles None invocation_state without AttributeError.""" + mock_model.mock_stream.return_value = [ + {"contentBlockDelta": {"delta": {"text": "test response"}}}, + {"contentBlockStop": {}}, + ] + + # This should not raise AttributeError: 'NoneType' object has no attribute 'get' + result = agent("test", invocation_state=None) + + assert result.message["content"][0]["text"] == "test response" + assert result.stop_reason == "end_turn" From 63991c6b65cf8402f79b969d04e4a4c3fc1ecaf6 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Mon, 6 Oct 2025 13:42:54 -0400 Subject: [PATCH 04/13] Update src/strands/agent/agent.py Co-authored-by: Nick Clegg --- src/strands/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index f39796b1e..d224658ea 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -534,7 +534,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) async def stream_async( - self, prompt: AgentInput = None, *, invocation_state: Optional[dict[str, 
Any]] = None, **kwargs: Any + self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. From 3929266ef0178cc4925daffdfdf1355bb20f671a Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 2 Oct 2025 12:17:01 -0400 Subject: [PATCH 05/13] tool - executors - concurrent - remove no-op gather (#954) --- src/strands/tools/executors/concurrent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 767071bae..8ef8a8b65 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -72,8 +72,6 @@ async def _execute( yield event task_events[task_id].set() - asyncio.gather(*tasks) - async def _task( self, agent: "Agent", From ebdeb5f27c814e058fab36aa585ae891268a094d Mon Sep 17 00:00:00 2001 From: poshinchen Date: Thu, 2 Oct 2025 12:35:22 -0400 Subject: [PATCH 06/13] feat(telemetry): updated traces to match OTEL v1.37 semantic conventions (#952) --- src/strands/telemetry/tracer.py | 332 ++++++++++++++++++------ src/strands/types/traces.py | 19 +- tests/strands/telemetry/test_tracer.py | 340 ++++++++++++++++++++++++- 3 files changed, 610 insertions(+), 81 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index d1862b859..b39de27ea 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -6,6 +6,7 @@ import json import logging +import os from datetime import date, datetime, timezone from typing import Any, Dict, Mapping, Optional @@ -17,7 +18,7 @@ from ..types.content import ContentBlock, Message, Messages from ..types.streaming import StopReason, Usage from ..types.tools import ToolResult, ToolUse -from ..types.traces import AttributeValue +from ..types.traces import Attributes, AttributeValue logger = logging.getLogger(__name__) @@ -90,6 +91,19 @@ def __init__( self.tracer = self.tracer_provider.get_tracer(self.service_name) ThreadingInstrumentor().instrument() + # Read OTEL_SEMCONV_STABILITY_OPT_IN environment variable + self.use_latest_genai_conventions = self._parse_semconv_opt_in() + + def _parse_semconv_opt_in(self) -> bool: + """Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. + + Returns: + True if "gen_ai_latest_experimental" is among the opt-in values, False otherwise. + """ + opt_in_env = os.getenv("OTEL_SEMCONV_STABILITY_OPT_IN", "") + + return "gen_ai_latest_experimental" in opt_in_env + def _start_span( self, span_name: str, @@ -194,7 +208,7 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Optiona error = exception or Exception(error_message) self._end_span(span, error=error) - def _add_event(self, span: Optional[Span], event_name: str, event_attributes: Dict[str, AttributeValue]) -> None: + def _add_event(self, span: Optional[Span], event_name: str, event_attributes: Attributes) -> None: """Add an event with attributes to a span. Args: @@ -249,10 +263,7 @@ def start_model_invoke_span( Returns: The created span, or None if tracing is not enabled.
""" - attributes: Dict[str, AttributeValue] = { - "gen_ai.system": "strands-agents", - "gen_ai.operation.name": "chat", - } + attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") if model_id: attributes["gen_ai.request.model"] = model_id @@ -261,12 +272,8 @@ def start_model_invoke_span( attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))}) span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT) - for message in messages: - self._add_event( - span, - self._get_event_name_for_message(message), - {"content": serialize(message["content"])}, - ) + self._add_event_messages(span, messages) + return span def end_model_invoke_span( @@ -291,11 +298,28 @@ def end_model_invoke_span( "gen_ai.usage.cache_write_input_tokens": usage.get("cacheWriteInputTokens", 0), } - self._add_event( - span, - "gen_ai.choice", - event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, - ) + if self.use_latest_genai_conventions: + self._add_event( + span, + "gen_ai.client.inference.operation.details", + { + "gen_ai.output.messages": serialize( + [ + { + "role": message["role"], + "parts": [{"type": "text", "content": serialize(message["content"])}], + "finish_reason": str(stop_reason), + } + ] + ), + }, + ) + else: + self._add_event( + span, + "gen_ai.choice", + event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])}, + ) self._end_span(span, attributes, error) @@ -310,12 +334,13 @@ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = { - "gen_ai.operation.name": "execute_tool", - "gen_ai.system": "strands-agents", - "gen_ai.tool.name": tool["name"], - "gen_ai.tool.call.id": tool["toolUseId"], - } + attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") + attributes.update( + { + "gen_ai.tool.name": tool["name"], + "gen_ai.tool.call.id": tool["toolUseId"], + } + ) # Add additional kwargs as attributes attributes.update(kwargs) @@ -323,15 +348,38 @@ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None span_name = f"execute_tool {tool['name']}" span = self._start_span(span_name, parent_span, attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL) - self._add_event( - span, - "gen_ai.tool.message", - event_attributes={ - "role": "tool", - "content": serialize(tool["input"]), - "id": tool["toolUseId"], - }, - ) + if self.use_latest_genai_conventions: + self._add_event( + span, + "gen_ai.client.inference.operation.details", + { + "gen_ai.input.messages": serialize( + [ + { + "role": "tool", + "parts": [ + { + "type": "tool_call", + "name": tool["name"], + "id": tool["toolUseId"], + "arguments": [{"content": serialize(tool["input"])}], + } + ], + } + ] + ) + }, + ) + else: + self._add_event( + span, + "gen_ai.tool.message", + event_attributes={ + "role": "tool", + "content": serialize(tool["input"]), + "id": tool["toolUseId"], + }, + ) return span @@ -352,18 +400,40 @@ def end_tool_call_span( attributes.update( { - "tool.status": status_str, + "gen_ai.tool.status": status_str, } ) - self._add_event( - span, - "gen_ai.choice", - event_attributes={ - "message": serialize(tool_result.get("content")), - "id": tool_result.get("toolUseId", ""), - }, - ) + if self.use_latest_genai_conventions: + self._add_event( + span, + 
"gen_ai.client.inference.operation.details", + { + "gen_ai.output.messages": serialize( + [ + { + "role": "tool", + "parts": [ + { + "type": "tool_call_response", + "id": tool_result.get("toolUseId", ""), + "result": serialize(tool_result.get("content")), + } + ], + } + ] + ) + }, + ) + else: + self._add_event( + span, + "gen_ai.choice", + event_attributes={ + "message": serialize(tool_result.get("content")), + "id": tool_result.get("toolUseId", ""), + }, + ) self._end_span(span, attributes, error) @@ -400,12 +470,7 @@ def start_event_loop_cycle_span( span_name = "execute_event_loop_cycle" span = self._start_span(span_name, parent_span, attributes) - for message in messages or []: - self._add_event( - span, - self._get_event_name_for_message(message), - {"content": serialize(message["content"])}, - ) + self._add_event_messages(span, messages) return span @@ -429,7 +494,24 @@ def end_event_loop_cycle_span( if tool_result_message: event_attributes["tool.result"] = serialize(tool_result_message["content"]) - self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) + + if self.use_latest_genai_conventions: + self._add_event( + span, + "gen_ai.client.inference.operation.details", + { + "gen_ai.output.messages": serialize( + [ + { + "role": tool_result_message["role"], + "parts": [{"type": "text", "content": serialize(tool_result_message["content"])}], + } + ] + ) + }, + ) + else: + self._add_event(span, "gen_ai.choice", event_attributes=event_attributes) self._end_span(span, attributes, error) def start_agent_span( @@ -454,11 +536,12 @@ def start_agent_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = { - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": agent_name, - "gen_ai.operation.name": "invoke_agent", - } + attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") + attributes.update( + { + "gen_ai.agent.name": agent_name, + } + ) if model_id: attributes["gen_ai.request.model"] = model_id @@ -477,12 +560,7 @@ def start_agent_span( span = self._start_span( f"invoke_agent {agent_name}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT ) - for message in messages: - self._add_event( - span, - self._get_event_name_for_message(message), - {"content": serialize(message["content"])}, - ) + self._add_event_messages(span, messages) return span @@ -502,11 +580,28 @@ def end_agent_span( attributes: Dict[str, AttributeValue] = {} if response: - self._add_event( - span, - "gen_ai.choice", - event_attributes={"message": str(response), "finish_reason": str(response.stop_reason)}, - ) + if self.use_latest_genai_conventions: + self._add_event( + span, + "gen_ai.client.inference.operation.details", + { + "gen_ai.output.messages": serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": str(response)}], + "finish_reason": str(response.stop_reason), + } + ] + ) + }, + ) + else: + self._add_event( + span, + "gen_ai.choice", + event_attributes={"message": str(response), "finish_reason": str(response.stop_reason)}, + ) if hasattr(response, "metrics") and hasattr(response.metrics, "accumulated_usage"): accumulated_usage = response.metrics.accumulated_usage @@ -530,19 +625,33 @@ def start_multiagent_span( instance: str, ) -> Span: """Start a new span for swarm invocation.""" - attributes: Dict[str, AttributeValue] = { - "gen_ai.system": "strands-agents", - "gen_ai.agent.name": instance, - "gen_ai.operation.name": f"invoke_{instance}", - } + operation 
= f"invoke_{instance}" + attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation) + attributes.update( + { + "gen_ai.agent.name": instance, + } + ) - span = self._start_span(f"invoke_{instance}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT) + span = self._start_span(operation, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT) content = serialize(task) if isinstance(task, list) else task - self._add_event( - span, - "gen_ai.user.message", - event_attributes={"content": content}, - ) + + if self.use_latest_genai_conventions: + self._add_event( + span, + "gen_ai.client.inference.operation.details", + { + "gen_ai.input.messages": serialize( + [{"role": "user", "parts": [{"type": "text", "content": content}]}] + ) + }, + ) + else: + self._add_event( + span, + "gen_ai.user.message", + event_attributes={"content": content}, + ) return span @@ -553,11 +662,78 @@ def end_swarm_span( ) -> None: """End a swarm span with results.""" if result: + if self.use_latest_genai_conventions: + self._add_event( + span, + "gen_ai.client.inference.operation.details", + { + "gen_ai.output.messages": serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": result}], + } + ] + ) + }, + ) + else: + self._add_event( + span, + "gen_ai.choice", + event_attributes={"message": result}, + ) + + def _get_common_attributes( + self, + operation_name: str, + ) -> Dict[str, AttributeValue]: + """Returns a dictionary of common attributes based on the convention version used. + + Args: + operation_name: The name of the operation. + + Returns: + A dictionary of attributes following the appropriate GenAI conventions. + """ + common_attributes = {"gen_ai.operation.name": operation_name} + if self.use_latest_genai_conventions: + common_attributes.update( + { + "gen_ai.provider.name": "strands-agents", + } + ) + else: + common_attributes.update( + { + "gen_ai.system": "strands-agents", + } + ) + return dict(common_attributes) + + def _add_event_messages(self, span: Span, messages: Messages) -> None: + """Adds messages as event to the provided span based on the current GenAI conventions. + + Args: + span: The span to which events will be added. + messages: List of messages being sent to the agent. 
+ """ + if self.use_latest_genai_conventions: + input_messages: list = [] + for message in messages: + input_messages.append( + {"role": message["role"], "parts": [{"type": "text", "content": serialize(message["content"])}]} + ) self._add_event( - span, - "gen_ai.choice", - event_attributes={"message": result}, + span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize(input_messages)} ) + else: + for message in messages: + self._add_event( + span, + self._get_event_name_for_message(message), + {"content": serialize(message["content"])}, + ) # Singleton instance for global access diff --git a/src/strands/types/traces.py b/src/strands/types/traces.py index b850196ae..af6188adb 100644 --- a/src/strands/types/traces.py +++ b/src/strands/types/traces.py @@ -1,5 +1,20 @@ """Tracing type definitions for the SDK.""" -from typing import List, Union +from typing import List, Mapping, Optional, Sequence, Union -AttributeValue = Union[str, bool, float, int, List[str], List[bool], List[float], List[int]] +AttributeValue = Union[ + str, + bool, + float, + int, + List[str], + List[bool], + List[float], + List[int], + Sequence[str], + Sequence[bool], + Sequence[int], + Sequence[float], +] + +Attributes = Optional[Mapping[str, AttributeValue]] diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 8c4f9ae20..eed060294 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -163,6 +163,43 @@ def test_start_model_invoke_span(mock_tracer): assert span is not None +def test_start_model_invoke_span_latest_conventions(mock_tracer): + """Test starting a model invoke span with the latest semantic conventions.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.use_latest_genai_conventions = True + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + model_id = "test-model" + + span = tracer.start_model_invoke_span(messages=messages, agent_name="TestAgent", model_id=model_id) + + mock_tracer.start_span.assert_called_once() + assert mock_tracer.start_span.call_args[1]["name"] == "chat" + assert mock_tracer.start_span.call_args[1]["kind"] == SpanKind.CLIENT + mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") + mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "chat") + mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.input.messages": serialize( + [ + { + "role": messages[0]["role"], + "parts": [{"type": "text", "content": serialize(messages[0]["content"])}], + } + ] + ) + }, + ) + assert span is not None + + def test_end_model_invoke_span(mock_span): """Test ending a model invoke span.""" tracer = Tracer() @@ -187,6 +224,43 @@ def test_end_model_invoke_span(mock_span): mock_span.end.assert_called_once() +def test_end_model_invoke_span_latest_conventions(mock_span, mock_tracer): + """Test ending a model invoke span with the latest semantic conventions.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.use_latest_genai_conventions = True + message = {"role": "assistant", "content": [{"text": "Response"}]} + usage = Usage(inputTokens=10, outputTokens=20,
totalTokens=30) + stop_reason: StopReason = "end_turn" + + tracer.end_model_invoke_span(mock_span, message, usage, stop_reason) + + mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 10) + mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 10) + mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 20) + mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 20) + mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 30) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.output.messages": serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": serialize(message["content"])}], + "finish_reason": "end_turn", + } + ] + ), + }, + ) + + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + def test_start_tool_call_span(mock_tracer): """Test starting a tool call span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -212,6 +286,49 @@ def test_start_tool_call_span(mock_tracer): assert span is not None +def test_start_tool_call_span_latest_conventions(mock_tracer): + """Test starting a tool call span with the latest semantic conventions.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.use_latest_genai_conventions = True + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + tool = {"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}} + + span = tracer.start_tool_call_span(tool) + + mock_tracer.start_span.assert_called_once() + assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool" + mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool") + mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") + mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool") + mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123") + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.input.messages": serialize( + [ + { + "role": "tool", + "parts": [ + { + "type": "tool_call", + "name": tool["name"], + "id": tool["toolUseId"], + "arguments": [{"content": serialize(tool["input"])}], + } + ], + } + ] + ) + }, + ) + assert span is not None + + def test_start_swarm_call_span_with_string_task(mock_tracer): """Test starting a swarm call span with task as string.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -258,6 +375,36 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): assert span is not None +def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer): + """Test starting a swarm call span with task as list of contentBlock with latest semantic conventions.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.use_latest_genai_conventions = True + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + task = [ContentBlock(text="Original Task: foo bar")] + + 
span = tracer.start_multiagent_span(task, "swarm") + + mock_tracer.start_span.assert_called_once() + assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm" + mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") + mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm") + mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm") + mock_span.add_event.assert_any_call( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.input.messages": serialize( + [{"role": "user", "parts": [{"type": "text", "content": '[{"text": "Original Task: foo bar"}]'}]}] + ) + }, + ) + assert span is not None + + def test_end_swarm_span(mock_span): """Test ending a tool call span.""" tracer = Tracer() @@ -271,6 +418,29 @@ def test_end_swarm_span(mock_span): ) +def test_end_swarm_span_latest_conventions(mock_span): + """Test ending a swarm span with the latest semantic conventions.""" + tracer = Tracer() + tracer.use_latest_genai_conventions = True + swarm_final_result = "foo bar bar" + + tracer.end_swarm_span(mock_span, swarm_final_result) + + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.output.messages": serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "foo bar bar"}], + } + ] + ) + }, + ) + + def test_start_graph_call_span(mock_tracer): """Test starting a graph call span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -303,7 +473,7 @@ def test_end_tool_call_span(mock_span): tracer.end_tool_call_span(mock_span, tool_result) - mock_span.set_attribute.assert_any_call("tool.status", "success") + mock_span.set_attribute.assert_any_call("gen_ai.tool.status", "success") mock_span.add_event.assert_called_with( "gen_ai.choice", attributes={"message": json.dumps(tool_result.get("content")), "id": ""}, @@ -312,6 +482,38 @@ def test_end_tool_call_span(mock_span): mock_span.end.assert_called_once() +def test_end_tool_call_span_latest_conventions(mock_span): + """Test ending a tool call span with the latest semantic conventions.""" + tracer = Tracer() + tracer.use_latest_genai_conventions = True + tool_result = {"status": "success", "content": [{"text": "Tool result"}]} + + tracer.end_tool_call_span(mock_span, tool_result) + + mock_span.set_attribute.assert_any_call("gen_ai.tool.status", "success") + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.output.messages": serialize( + [ + { + "role": "tool", + "parts": [ + { + "type": "tool_call_response", + "id": tool_result.get("toolUseId", ""), + "result": serialize(tool_result.get("content")), + } + ], + } + ] + ) + }, + ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + def test_start_event_loop_cycle_span(mock_tracer): """Test starting an event loop cycle span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -335,6 +537,35 @@ def test_start_event_loop_cycle_span(mock_tracer): assert span is not None +def test_start_event_loop_cycle_span_latest_conventions(mock_tracer): + """Test starting an event loop cycle span with the latest semantic conventions.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.use_latest_genai_conventions = True + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock()
+ mock_tracer.start_span.return_value = mock_span + + event_loop_kwargs = {"event_loop_cycle_id": "cycle-123"} + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + span = tracer.start_event_loop_cycle_span(event_loop_kwargs, messages=messages) + + mock_tracer.start_span.assert_called_once() + assert mock_tracer.start_span.call_args[1]["name"] == "execute_event_loop_cycle" + mock_span.set_attribute.assert_any_call("event_loop.cycle_id", "cycle-123") + mock_span.add_event.assert_any_call( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.input.messages": serialize( + [{"role": "user", "parts": [{"type": "text", "content": serialize(messages[0]["content"])}]}] + ) + }, + ) + assert span is not None + + def test_end_event_loop_cycle_span(mock_span): """Test ending an event loop cycle span.""" tracer = Tracer() @@ -354,6 +585,32 @@ def test_end_event_loop_cycle_span(mock_span): mock_span.end.assert_called_once() +def test_end_event_loop_cycle_span_latest_conventions(mock_span): + """Test ending an event loop cycle span with the latest semantic conventions.""" + tracer = Tracer() + tracer.use_latest_genai_conventions = True + message = {"role": "assistant", "content": [{"text": "Response"}]} + tool_result_message = {"role": "assistant", "content": [{"toolResult": {"response": "Success"}}]} + + tracer.end_event_loop_cycle_span(mock_span, message, tool_result_message) + + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.output.messages": serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": serialize(tool_result_message["content"])}], + } + ] + ) + }, + ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + def test_start_agent_span(mock_tracer): """Test starting an agent span.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): @@ -386,6 +643,46 @@ def test_start_agent_span(mock_tracer): assert span is not None +def test_start_agent_span_latest_conventions(mock_tracer): + """Test starting an agent span with the latest semantic conventions.""" + with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + tracer = Tracer() + tracer.use_latest_genai_conventions = True + tracer.tracer = mock_tracer + + mock_span = mock.MagicMock() + mock_tracer.start_span.return_value = mock_span + + content = [{"text": "test prompt"}] + model_id = "test-model" + tools = [{"name": "weather_tool"}] + custom_attrs = {"custom_attr": "value"} + + span = tracer.start_agent_span( + custom_trace_attributes=custom_attrs, + agent_name="WeatherAgent", + messages=[{"content": content, "role": "user"}], + model_id=model_id, + tools=tools, + ) + + mock_tracer.start_span.assert_called_once() + assert mock_tracer.start_span.call_args[1]["name"] == "invoke_agent WeatherAgent" + mock_span.set_attribute.assert_any_call("gen_ai.provider.name", "strands-agents") + mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "WeatherAgent") + mock_span.set_attribute.assert_any_call("gen_ai.request.model", model_id) + mock_span.set_attribute.assert_any_call("custom_attr", "value") + mock_span.add_event.assert_any_call( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.input.messages": serialize( + [{"role": "user", "parts": [{"type": "text", "content": '[{"text": "test prompt"}]'}]}] + ) + }, + ) + assert span is not None + + def test_end_agent_span(mock_span): 
"""Test ending an agent span.""" tracer = Tracer() @@ -416,6 +713,47 @@ def test_end_agent_span(mock_span): mock_span.end.assert_called_once() +def test_end_agent_span_latest_conventions(mock_span): + """Test ending an agent span with the latest semantic conventions.""" + tracer = Tracer() + tracer.use_latest_genai_conventions = True + + # Mock AgentResult with metrics + mock_metrics = mock.MagicMock() + mock_metrics.accumulated_usage = {"inputTokens": 50, "outputTokens": 100, "totalTokens": 150} + + mock_response = mock.MagicMock() + mock_response.metrics = mock_metrics + mock_response.stop_reason = "end_turn" + mock_response.__str__ = mock.MagicMock(return_value="Agent response") + + tracer.end_agent_span(mock_span, mock_response) + + mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) + mock_span.set_attribute.assert_any_call("gen_ai.usage.input_tokens", 50) + mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) + mock_span.set_attribute.assert_any_call("gen_ai.usage.output_tokens", 100) + mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 0) + mock_span.set_attribute.assert_any_call("gen_ai.usage.cache_write_input_tokens", 0) + mock_span.add_event.assert_called_with( + "gen_ai.client.inference.operation.details", + attributes={ + "gen_ai.output.messages": serialize( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "Agent response"}], + "finish_reason": "end_turn", + } + ] + ) + }, + ) + mock_span.set_status.assert_called_once_with(StatusCode.OK) + mock_span.end.assert_called_once() + + def test_end_model_invoke_span_with_cache_metrics(mock_span): """Test ending a model invoke span with cache metrics.""" tracer = Tracer() From b01eeda4d5668f8cd3a35f40b924f9e816080a87 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 2 Oct 2025 14:20:58 -0400 Subject: [PATCH 07/13] event loop - handle model execution (#958) --- src/strands/event_loop/event_loop.py | 234 +++++++++++++++------------ 1 file changed, 135 insertions(+), 99 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index f2eed063c..d6367e9d9 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -17,7 +17,7 @@ from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent from ..telemetry.metrics import Trace -from ..telemetry.tracer import get_tracer +from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools from ..types._events import ( EventLoopStopEvent, @@ -37,7 +37,7 @@ MaxTokensReachedException, ModelThrottledException, ) -from ..types.streaming import Metrics, StopReason +from ..types.streaming import StopReason from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from .streaming import stream_messages @@ -106,16 +106,142 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> ) invocation_state["event_loop_cycle_span"] = cycle_span + model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer) + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + + stop_reason, message, *_ = model_event["stop"] + yield ModelMessageEvent(message=message) + + try: + if stop_reason == "max_tokens": + """ + 
Handle max_tokens limit reached by the model. + + When the model reaches its maximum token limit, this represents a potentially unrecoverable + state where the model's response was truncated. By default, Strands fails hard with an + MaxTokensReachedException to maintain consistency with other failure types. + """ + raise MaxTokensReachedException( + message=( + "Agent has reached an unrecoverable state due to max_tokens limit. " + "For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" + ) + ) + + # If the model is requesting to use tools + if stop_reason == "tool_use": + # Handle tool execution + tool_events = _handle_tool_execution( + stop_reason, + message, + agent=agent, + cycle_trace=cycle_trace, + cycle_span=cycle_span, + cycle_start_time=cycle_start_time, + invocation_state=invocation_state, + ) + async for tool_event in tool_events: + yield tool_event + + return + + # End the cycle and return results + agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) + if cycle_span: + tracer.end_event_loop_cycle_span( + span=cycle_span, + message=message, + ) + except EventLoopException as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + + # Don't yield or log the exception - we already did it when we + # raised the exception and we don't need that duplication. + raise + except (ContextWindowOverflowException, MaxTokensReachedException) as e: + # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + raise e + except Exception as e: + if cycle_span: + tracer.end_span_with_error(cycle_span, str(e), e) + + # Handle any other exceptions + yield ForceStopEvent(reason=e) + logger.exception("cycle failed") + raise EventLoopException(e, invocation_state["request_state"]) from e + + yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) + + +async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: + """Make a recursive call to event_loop_cycle with the current state. + + This function is used when the event loop needs to continue processing after tool execution. + + Args: + agent: Agent for which the recursive call is being made. + invocation_state: Arguments to pass through event_loop_cycle + + + Yields: + Results from event_loop_cycle where the last result contains: + + - StopReason: Reason the model stopped generating + - Message: The generated message from the model + - EventLoopMetrics: Updated metrics for the event loop + - Any: Updated request state + """ + cycle_trace = invocation_state["event_loop_cycle_trace"] + + # Recursive call trace + recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) + cycle_trace.add_child(recursive_trace) + + yield StartEvent() + + events = event_loop_cycle(agent=agent, invocation_state=invocation_state) + async for event in events: + yield event + + recursive_trace.end() + + +async def _handle_model_execution( + agent: "Agent", + cycle_span: Any, + cycle_trace: Trace, + invocation_state: dict[str, Any], + tracer: Tracer, +) -> AsyncGenerator[TypedEvent, None]: + """Handle model execution with retry logic for throttling exceptions. + + Executes the model inference with automatic retry handling for throttling exceptions. + Manages tracing, hooks, and metrics collection throughout the process. 
+ + Args: + agent: The agent executing the model. + cycle_span: Span object for tracing the cycle. + cycle_trace: Trace object for the current event loop cycle. + invocation_state: State maintained across cycles. + tracer: Tracer instance for span management. + + Yields: + Model stream events and throttle events during retries. + + Raises: + ModelThrottledException: If max retry attempts are exceeded. + Exception: Any other model execution errors. + """ # Create a trace for the stream_messages call stream_trace = Trace("stream_messages", parent_id=cycle_trace.id) cycle_trace.add_child(stream_trace) - # Process messages with exponential backoff for throttling - message: Message - stop_reason: StopReason - usage: Any - metrics: Metrics - # Retry loop for handling throttling exceptions current_delay = INITIAL_DELAY for attempt in range(MAX_ATTEMPTS): @@ -136,8 +262,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> try: async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): - if not isinstance(event, ModelStopReason): - yield event + yield event stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -198,108 +323,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> # Add the response message to the conversation agent.messages.append(message) agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) - yield ModelMessageEvent(message=message) # Update metrics agent.event_loop_metrics.update_usage(usage) agent.event_loop_metrics.update_metrics(metrics) - if stop_reason == "max_tokens": - """ - Handle max_tokens limit reached by the model. - - When the model reaches its maximum token limit, this represents a potentially unrecoverable - state where the model's response was truncated. By default, Strands fails hard with an - MaxTokensReachedException to maintain consistency with other failure types. - """ - raise MaxTokensReachedException( - message=( - "Agent has reached an unrecoverable state due to max_tokens limit. " - "For more information see: " - "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception" - ) - ) - - # If the model is requesting to use tools - if stop_reason == "tool_use": - # Handle tool execution - events = _handle_tool_execution( - stop_reason, - message, - agent=agent, - cycle_trace=cycle_trace, - cycle_span=cycle_span, - cycle_start_time=cycle_start_time, - invocation_state=invocation_state, - ) - async for typed_event in events: - yield typed_event - - return - - # End the cycle and return results - agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes) - if cycle_span: - tracer.end_event_loop_cycle_span( - span=cycle_span, - message=message, - ) - except EventLoopException as e: - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - - # Don't yield or log the exception - we already did it when we - # raised the exception and we don't need that duplication. 
- raise - except (ContextWindowOverflowException, MaxTokensReachedException) as e: - # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException - if cycle_span: - tracer.end_span_with_error(cycle_span, str(e), e) - raise e except Exception as e: if cycle_span: tracer.end_span_with_error(cycle_span, str(e), e) - # Handle any other exceptions yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e - yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"]) - - -async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]: - """Make a recursive call to event_loop_cycle with the current state. - - This function is used when the event loop needs to continue processing after tool execution. - - Args: - agent: Agent for which the recursive call is being made. - invocation_state: Arguments to pass through event_loop_cycle - - - Yields: - Results from event_loop_cycle where the last result contains: - - - StopReason: Reason the model stopped generating - - Message: The generated message from the model - - EventLoopMetrics: Updated metrics for the event loop - - Any: Updated request state - """ - cycle_trace = invocation_state["event_loop_cycle_trace"] - - # Recursive call trace - recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id) - cycle_trace.add_child(recursive_trace) - - yield StartEvent() - - events = event_loop_cycle(agent=agent, invocation_state=invocation_state) - async for event in events: - yield event - - recursive_trace.end() - async def _handle_tool_execution( stop_reason: StopReason, From b50bca5d200f2e52ac08f1f36dd06b0224d88159 Mon Sep 17 00:00:00 2001 From: Vamil Gandhi Date: Thu, 2 Oct 2025 16:52:49 -0400 Subject: [PATCH 08/13] feat: implement concurrent message reading for session managers (#897) Replace sequential message loading with async concurrent reading in both S3SessionManager and FileSessionManager to improve performance for long conversations. Uses asyncio.gather() with run_in_executor() to read multiple messages simultaneously while maintaining proper ordering. 
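A minimal sketch of the pattern this change applies, with illustrative names (`load_all`, `read_one`, `keys` are not part of the patch):

```python
import asyncio

async def load_all(read_one, keys):
    # Run each blocking per-message read in the default thread-pool executor
    # so the reads overlap; asyncio.gather returns results in the order of
    # `keys`, which preserves message ordering.
    loop = asyncio.get_event_loop()
    tasks = [loop.run_in_executor(None, read_one, key) for key in keys]
    return await asyncio.gather(*tasks)

# Example: messages = asyncio.run(load_all(read_file, message_files))
```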
Resolves: #874 Co-authored-by: Vamil Gandhi --- src/strands/session/file_session_manager.py | 20 ++++++++++++---- src/strands/session/s3_session_manager.py | 26 ++++++++++++++------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 491f7ad60..93adeb7f2 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -1,5 +1,6 @@ """File-based session manager for local filesystem storage.""" +import asyncio import json import logging import os @@ -231,11 +232,20 @@ def list_messages( else: message_files = message_files[offset:] - # Load only the message files - messages: list[SessionMessage] = [] - for filename in message_files: + return asyncio.run(self._load_messages_concurrently(messages_dir, message_files)) + + async def _load_messages_concurrently(self, messages_dir: str, message_files: list[str]) -> list[SessionMessage]: + """Load multiple message files concurrently using async.""" + if not message_files: + return [] + + async def load_message(filename: str) -> SessionMessage: file_path = os.path.join(messages_dir, filename) - message_data = self._read_file(file_path) - messages.append(SessionMessage.from_dict(message_data)) + loop = asyncio.get_event_loop() + message_data = await loop.run_in_executor(None, self._read_file, file_path) + return SessionMessage.from_dict(message_data) + + tasks = [load_message(filename) for filename in message_files] + messages = await asyncio.gather(*tasks) return messages diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index c6ce28d80..1f6ffe7f1 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -1,5 +1,6 @@ """S3-based session manager for cloud storage.""" +import asyncio import json import logging from typing import Any, Dict, List, Optional, cast @@ -283,14 +284,23 @@ def list_messages( else: message_keys = message_keys[offset:] - # Load only the required message objects - messages: List[SessionMessage] = [] - for key in message_keys: - message_data = self._read_s3_object(key) - if message_data: - messages.append(SessionMessage.from_dict(message_data)) - - return messages + # Load message objects concurrently using async + return asyncio.run(self._load_messages_concurrently(message_keys)) except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e + + async def _load_messages_concurrently(self, message_keys: List[str]) -> List[SessionMessage]: + """Load multiple message objects concurrently using async.""" + if not message_keys: + return [] + + async def load_message(key: str) -> Optional[SessionMessage]: + loop = asyncio.get_event_loop() + message_data = await loop.run_in_executor(None, self._read_s3_object, key) + return SessionMessage.from_dict(message_data) if message_data else None + + tasks = [load_message(key) for key in message_keys] + loaded_messages = await asyncio.gather(*tasks) + + return [msg for msg in loaded_messages if msg is not None] From 9900ccfb24d728fa17f7feec2d99f2a9c9a280a8 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 3 Oct 2025 17:29:43 -0400 Subject: [PATCH 09/13] hooks - before tool call event - cancel tool (#964) --- src/strands/hooks/events.py | 8 +++- src/strands/tools/executors/_executor.py | 29 ++++++++++++++- src/strands/types/_events.py | 23 ++++++++++++ .../strands/tools/executors/test_executor.py | 37 
++++++++++++++++++- tests_integ/tools/executors/conftest.py | 15 ++++++++ .../tools/executors/test_concurrent.py | 16 ++++++++ .../tools/executors/test_sequential.py | 16 ++++++++ 7 files changed, 140 insertions(+), 4 deletions(-) create mode 100644 tests_integ/tools/executors/conftest.py diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index b3b2014f3..8f611e4e2 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -97,14 +97,18 @@ class BeforeToolCallEvent(HookEvent): to change which tool gets executed. This may be None if tool lookup failed. tool_use: The tool parameters that will be passed to selected_tool. invocation_state: Keyword arguments that will be passed to the tool. + cancel_tool: A user defined message that when set, will cancel the tool call. + The message will be placed into a tool result with an error status. If set to `True`, Strands will cancel + the tool call and use a default cancel message. """ selected_tool: Optional[AgentTool] tool_use: ToolUse invocation_state: dict[str, Any] + cancel_tool: bool | str = False def _can_write(self, name: str) -> bool: - return name in ["selected_tool", "tool_use"] + return name in ["cancel_tool", "selected_tool", "tool_use"] @dataclass @@ -124,6 +128,7 @@ class AfterToolCallEvent(HookEvent): invocation_state: Keyword arguments that were passed to the tool result: The result of the tool invocation. Either a ToolResult on success or an Exception if the tool execution failed. + cancel_message: The cancellation message if the user cancelled the tool call. """ selected_tool: Optional[AgentTool] @@ -131,6 +136,7 @@ class AfterToolCallEvent(HookEvent): invocation_state: dict[str, Any] result: ToolResult exception: Optional[Exception] = None + cancel_message: str | None = None def _can_write(self, name: str) -> bool: return name == "result" diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 2a75c48f2..f78861f81 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -14,7 +14,7 @@ from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer -from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent +from ...types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse @@ -81,6 +81,31 @@ async def _stream( ) ) + if before_event.cancel_tool: + cancel_message = ( + before_event.cancel_tool if isinstance(before_event.cancel_tool, str) else "tool cancelled by user" + ) + yield ToolCancelEvent(tool_use, cancel_message) + + cancel_result: ToolResult = { + "toolUseId": str(tool_use.get("toolUseId")), + "status": "error", + "content": [{"text": cancel_message}], + } + after_event = agent.hooks.invoke_callbacks( + AfterToolCallEvent( + agent=agent, + tool_use=tool_use, + invocation_state=invocation_state, + selected_tool=None, + result=cancel_result, + cancel_message=cancel_message, + ) + ) + yield ToolResultEvent(after_event.result) + tool_results.append(after_event.result) + return + try: selected_tool = before_event.selected_tool tool_use = before_event.tool_use @@ -123,7 +148,7 @@ async def _stream( # so that we don't needlessly yield ToolStreamEvents for non-generator callbacks. 
diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py
index 3d0f1d0f0..e20bf658a 100644
--- a/src/strands/types/_events.py
+++ b/src/strands/types/_events.py
@@ -298,6 +298,29 @@ def tool_use_id(self) -> str:
         return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId"))
 
 
+class ToolCancelEvent(TypedEvent):
+    """Event emitted when a user cancels a tool call from their BeforeToolCallEvent hook."""
+
+    def __init__(self, tool_use: ToolUse, message: str) -> None:
+        """Initialize with tool cancellation data.
+
+        Args:
+            tool_use: Information about the tool being cancelled
+            message: The tool cancellation message
+        """
+        super().__init__({"tool_cancel_event": {"tool_use": tool_use, "message": message}})
+
+    @property
+    def tool_use_id(self) -> str:
+        """The id of the cancelled tool use."""
+        return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId"))
+
+    @property
+    def message(self) -> str:
+        """The tool cancellation message."""
+        return cast(str, cast(dict, self.get("tool_cancel_event")).get("message"))
+
+
 class ModelMessageEvent(TypedEvent):
     """Event emitted when the model invocation has completed.
 
diff --git a/tests/strands/tools/executors/test_executor.py b/tests/strands/tools/executors/test_executor.py
index 3bbedb477..2a0a44e10 100644
--- a/tests/strands/tools/executors/test_executor.py
+++ b/tests/strands/tools/executors/test_executor.py
@@ -7,7 +7,7 @@
 from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent
 from strands.telemetry.metrics import Trace
 from strands.tools.executors._executor import ToolExecutor
-from strands.types._events import ToolResultEvent, ToolStreamEvent
+from strands.types._events import ToolCancelEvent, ToolResultEvent, ToolStreamEvent
 from strands.types.tools import ToolUse
 
 
@@ -215,3 +215,38 @@ async def test_executor_stream_with_trace(
     cycle_trace.add_child.assert_called_once()
 
     assert isinstance(cycle_trace.add_child.call_args[0][0], Trace)
+
+
+@pytest.mark.parametrize(
+    ("cancel_tool", "cancel_message"),
+    [(True, "tool cancelled by user"), ("user cancel message", "user cancel message")],
+)
+@pytest.mark.asyncio
+async def test_executor_stream_cancel(
+    cancel_tool, cancel_message, executor, agent, tool_results, invocation_state, alist
+):
+    def cancel_callback(event):
+        event.cancel_tool = cancel_tool
+        return event
+
+    agent.hooks.add_callback(BeforeToolCallEvent, cancel_callback)
+
+    tool_use: ToolUse = {"name": "weather_tool", "toolUseId": "1", "input": {}}
+    stream = executor._stream(agent, tool_use, tool_results, invocation_state)
+
+    tru_events = await alist(stream)
+    exp_events = [
+        ToolCancelEvent(tool_use, cancel_message),
+        ToolResultEvent(
+            {
+                "toolUseId": "1",
+                "status": "error",
+                "content": [{"text": cancel_message}],
+            },
+        ),
+    ]
+    assert tru_events == exp_events
+
+    tru_results = tool_results
+    exp_results = [exp_events[-1].tool_result]
+    assert tru_results == exp_results
diff --git a/tests_integ/tools/executors/conftest.py b/tests_integ/tools/executors/conftest.py
new file mode 100644
index 000000000..c8e7fed95
--- /dev/null
+++ b/tests_integ/tools/executors/conftest.py
@@ -0,0 +1,15 @@
+import pytest
+
+from strands.hooks import BeforeToolCallEvent, HookProvider
+
+
+@pytest.fixture
+def cancel_hook():
+    class Hook(HookProvider):
+        def register_hooks(self, registry):
+            registry.add_callback(BeforeToolCallEvent, self.cancel)
+
+        def cancel(self, event):
+            event.cancel_tool = "cancelled tool call"
+
+    return Hook()

diff --git a/tests_integ/tools/executors/test_concurrent.py b/tests_integ/tools/executors/test_concurrent.py
index 27dd468e0..48653af9c 100644
--- a/tests_integ/tools/executors/test_concurrent.py
+++ b/tests_integ/tools/executors/test_concurrent.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 
 import pytest
 
@@ -59,3 +60,18 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
         {"name": "time_tool", "event": "end"},
     ]
     assert tru_events == exp_events
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
+    agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])
+
+    exp_message = "cancelled tool call"
+    tru_message = ""
+    async for event in agent.stream_async("What is the time in New York?"):
+        if "tool_cancel_event" in event:
+            tru_message = event["tool_cancel_event"]["message"]
+
+    assert tru_message == exp_message
+    assert len(tool_events) == 0
+    assert exp_message in json.dumps(agent.messages)

diff --git a/tests_integ/tools/executors/test_sequential.py b/tests_integ/tools/executors/test_sequential.py
index 82fc51a59..d959222d4 100644
--- a/tests_integ/tools/executors/test_sequential.py
+++ b/tests_integ/tools/executors/test_sequential.py
@@ -1,4 +1,5 @@
 import asyncio
+import json
 
 import pytest
 
@@ -59,3 +60,18 @@ async def test_agent_invoke_async_tool_executor(agent, tool_events):
         {"name": "weather_tool", "event": "end"},
     ]
     assert tru_events == exp_events
+
+
+@pytest.mark.asyncio
+async def test_agent_stream_async_tool_executor_cancelled(cancel_hook, tool_executor, time_tool, tool_events):
+    agent = Agent(tools=[time_tool], tool_executor=tool_executor, hooks=[cancel_hook])
+
+    exp_message = "cancelled tool call"
+    tru_message = ""
+    async for event in agent.stream_async("What is the time in New York?"):
+        if "tool_cancel_event" in event:
+            tru_message = event["tool_cancel_event"]["message"]
+
+    assert tru_message == exp_message
+    assert len(tool_events) == 0
+    assert exp_message in json.dumps(agent.messages)
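The integration fixtures above exercise the hook from inside the test suite; a hedged end-user sketch of the same mechanism follows. The dangerous_delete tool and the policy check are hypothetical, while Agent, tool, BeforeToolCallEvent, and cancel_tool are the names used in this patch:

from strands import Agent, tool
from strands.hooks import BeforeToolCallEvent

@tool
def dangerous_delete(path: str) -> str:
    """Hypothetical destructive tool."""
    return f"deleted {path}"

def guard(event: BeforeToolCallEvent) -> None:
    # A string cancels the call and becomes the error ToolResult text;
    # `event.cancel_tool = True` would use the default "tool cancelled by user".
    if event.tool_use["name"] == "dangerous_delete":
        event.cancel_tool = "blocked by policy hook"

agent = Agent(tools=[dangerous_delete])
agent.hooks.add_callback(BeforeToolCallEvent, guard)

Because the executor returns immediately after emitting the ToolCancelEvent and the error ToolResult, the cancelled tool body never runs, which is what the len(tool_events) == 0 assertions above verify.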
"tool_call_response", "id": tool_result.get("toolUseId", ""), - "result": serialize(tool_result.get("content")), + "result": tool_result.get("content"), } ], } @@ -504,7 +504,7 @@ def end_event_loop_cycle_span( [ { "role": tool_result_message["role"], - "parts": [{"type": "text", "content": serialize(tool_result_message["content"])}], + "parts": [{"type": "text", "content": tool_result_message["content"]}], } ] ) @@ -640,11 +640,7 @@ def start_multiagent_span( self._add_event( span, "gen_ai.client.inference.operation.details", - { - "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": content}]}] - ) - }, + {"gen_ai.input.messages": serialize([{"role": "user", "parts": [{"type": "text", "content": task}]}])}, ) else: self._add_event( @@ -722,7 +718,7 @@ def _add_event_messages(self, span: Span, messages: Messages) -> None: input_messages: list = [] for message in messages: input_messages.append( - {"role": message["role"], "parts": [{"type": "text", "content": serialize(message["content"])}]} + {"role": message["role"], "parts": [{"type": "text", "content": message["content"]}]} ) self._add_event( span, "gen_ai.client.inference.operation.details", {"gen_ai.input.messages": serialize(input_messages)} diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index eed060294..4e9872100 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -191,7 +191,7 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer): [ { "role": messages[0]["role"], - "parts": [{"type": "text", "content": serialize(messages[0]["content"])}], + "parts": [{"type": "text", "content": messages[0]["content"]}], } ] ) @@ -249,7 +249,7 @@ def test_end_model_invoke_span_latest_conventions(mock_span): [ { "role": "assistant", - "parts": [{"type": "text", "content": serialize(message["content"])}], + "parts": [{"type": "text", "content": message["content"]}], "finish_reason": "end_turn", } ] @@ -318,7 +318,7 @@ def test_start_tool_call_span_latest_conventions(mock_tracer): "type": "tool_call", "name": tool["name"], "id": tool["toolUseId"], - "arguments": [{"content": serialize(tool["input"])}], + "arguments": [{"content": tool["input"]}], } ], } @@ -398,7 +398,7 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer) "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": '[{"text": "Original Task: foo bar"}]'}]}] + [{"role": "user", "parts": [{"type": "text", "content": [{"text": "Original Task: foo bar"}]}]}] ) }, ) @@ -502,7 +502,7 @@ def test_end_tool_call_span_latest_conventions(mock_span): { "type": "tool_call_response", "id": tool_result.get("toolUseId", ""), - "result": serialize(tool_result.get("content")), + "result": tool_result.get("content"), } ], } @@ -559,7 +559,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer): "gen_ai.client.inference.operation.details", attributes={ "gen_ai.input.messages": serialize( - [{"role": "user", "parts": [{"type": "text", "content": serialize(messages[0]["content"])}]}] + [{"role": "user", "parts": [{"type": "text", "content": messages[0]["content"]}]}] ) }, ) @@ -601,7 +601,7 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span): [ { "role": "assistant", - "parts": [{"type": "text", "content": serialize(tool_result_message["content"])}], + "parts": [{"type": "text", "content": 
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index eed060294..4e9872100 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -191,7 +191,7 @@ def test_start_model_invoke_span_latest_conventions(mock_tracer):
         [
             {
                 "role": messages[0]["role"],
-                "parts": [{"type": "text", "content": serialize(messages[0]["content"])}],
+                "parts": [{"type": "text", "content": messages[0]["content"]}],
             }
         ]
     )
@@ -249,7 +249,7 @@ def test_end_model_invoke_span_latest_conventions(mock_span):
         [
             {
                 "role": "assistant",
-                "parts": [{"type": "text", "content": serialize(message["content"])}],
+                "parts": [{"type": "text", "content": message["content"]}],
                 "finish_reason": "end_turn",
             }
         ]
@@ -318,7 +318,7 @@ def test_start_tool_call_span_latest_conventions(mock_tracer):
                     "type": "tool_call",
                     "name": tool["name"],
                     "id": tool["toolUseId"],
-                    "arguments": [{"content": serialize(tool["input"])}],
+                    "arguments": [{"content": tool["input"]}],
                 }
             ],
         }
@@ -398,7 +398,7 @@ def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer):
         "gen_ai.client.inference.operation.details",
         attributes={
             "gen_ai.input.messages": serialize(
-                [{"role": "user", "parts": [{"type": "text", "content": '[{"text": "Original Task: foo bar"}]'}]}]
+                [{"role": "user", "parts": [{"type": "text", "content": [{"text": "Original Task: foo bar"}]}]}]
             )
         },
     )
@@ -502,7 +502,7 @@ def test_end_tool_call_span_latest_conventions(mock_span):
                 {
                     "type": "tool_call_response",
                     "id": tool_result.get("toolUseId", ""),
-                    "result": serialize(tool_result.get("content")),
+                    "result": tool_result.get("content"),
                 }
             ],
         }
@@ -559,7 +559,7 @@ def test_start_event_loop_cycle_span_latest_conventions(mock_tracer):
         "gen_ai.client.inference.operation.details",
         attributes={
             "gen_ai.input.messages": serialize(
-                [{"role": "user", "parts": [{"type": "text", "content": serialize(messages[0]["content"])}]}]
+                [{"role": "user", "parts": [{"type": "text", "content": messages[0]["content"]}]}]
            )
        },
    )
@@ -601,7 +601,7 @@ def test_end_event_loop_cycle_span_latest_conventions(mock_span):
         [
             {
                 "role": "assistant",
-                "parts": [{"type": "text", "content": serialize(tool_result_message["content"])}],
+                "parts": [{"type": "text", "content": tool_result_message["content"]}],
             }
         ]
     )
@@ -676,7 +676,7 @@ def test_start_agent_span_latest_conventions(mock_tracer):
         "gen_ai.client.inference.operation.details",
         attributes={
             "gen_ai.input.messages": serialize(
-                [{"role": "user", "parts": [{"type": "text", "content": '[{"text": "test prompt"}]'}]}]
+                [{"role": "user", "parts": [{"type": "text", "content": [{"text": "test prompt"}]}]}]
             )
         },
     )
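The fix keeps message content as plain Python structures until the whole gen_ai.input.messages envelope is serialized once. A hedged sketch of the before/after difference, where serialize stands in for the tracer's helper and the message literal is illustrative:

import json

def serialize(obj) -> str:
    # Stand-in for the tracer's serialize() helper.
    return json.dumps(obj)

content = [{"text": "hello"}]

# Before: content serialized twice, so the span event carried an escaped JSON string.
double = serialize([{"role": "user", "parts": [{"type": "text", "content": serialize(content)}]}])
# ... "content": "[{\"text\": \"hello\"}]" ...

# After: serialize once at the envelope; content stays a JSON array in the emitted event.
single = serialize([{"role": "user", "parts": [{"type": "text", "content": content}]}])
# ... "content": [{"text": "hello"}] ...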
From d584ad21f05889f32bea5f05108cc2d8f0ebd11c Mon Sep 17 00:00:00 2001
From: ratish <114130421+Ratish1@users.noreply.github.com>
Date: Tue, 7 Oct 2025 22:43:53 +0400
Subject: [PATCH 11/13] fix(litellm): map LiteLLM context-window errors to ContextWindowOverflowException (#994)

---
 src/strands/models/litellm.py        | 31 +++++++++++++++++++++-------
 tests/strands/models/test_litellm.py | 12 +++++++++++
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 005eed3df..1763f5dec 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -8,11 +8,13 @@
 from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast
 
 import litellm
+from litellm.exceptions import ContextWindowExceededError
 from litellm.utils import supports_response_schema
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
 from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
@@ -135,7 +137,11 @@ async def stream(
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
-        response = await litellm.acompletion(**self.client_args, **request)
+        try:
+            response = await litellm.acompletion(**self.client_args, **request)
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -205,15 +211,24 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        if not supports_response_schema(self.get_config()["model_id"]):
+        supports_schema = supports_response_schema(self.get_config()["model_id"])
+
+        # If the provider does not support response schemas, we cannot reliably parse structured output.
+        # In that case we must not call the provider and must raise the documented ValueError.
+        if not supports_schema:
             raise ValueError("Model does not support response_format")
 
-        response = await litellm.acompletion(
-            **self.client_args,
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        # For providers that DO support response schemas, call litellm and map context-window errors.
+        try:
+            response = await litellm.acompletion(
+                **self.client_args,
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow in structured_output")
+            raise ContextWindowOverflowException(e) from e
 
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")

diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index bc81fc819..776ae7bae 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -3,9 +3,11 @@
 import pydantic
 import pytest
+from litellm.exceptions import ContextWindowExceededError
 
 import strands
 from strands.models.litellm import LiteLLMModel
+from strands.types.exceptions import ContextWindowOverflowException
 
 
 @pytest.fixture
@@ -332,3 +334,13 @@ def test_tool_choice_none_no_warning(model, messages, captured_warnings):
     model.format_request(messages, tool_choice=None)
 
     assert len(captured_warnings) == 0
+
+
+@pytest.mark.asyncio
+async def test_context_window_maps_to_typed_exception(litellm_acompletion, model):
+    """Test that a typed ContextWindowExceededError is mapped correctly."""
+    litellm_acompletion.side_effect = ContextWindowExceededError(message="test error", model="x", llm_provider="y")
+
+    with pytest.raises(ContextWindowOverflowException):
+        async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]):
+            pass
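Downstream code can now branch on the SDK's typed exception instead of parsing provider-specific error strings. A hedged caller-side sketch; the model id and the naive truncation strategy are illustrative, not a recommendation:

from strands.models.litellm import LiteLLMModel
from strands.types.exceptions import ContextWindowOverflowException

model = LiteLLMModel(model_id="openai/gpt-4o")  # illustrative model id

async def stream_with_fallback(messages):
    try:
        async for event in model.stream(messages):
            yield event
    except ContextWindowOverflowException:
        # Same typed signal regardless of the underlying provider:
        # shrink the conversation (naively here) and retry once.
        async for event in model.stream(messages[-2:]):
            yield event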
"test response" assert result.stop_reason == "end_turn" + + +def test_agent__call__invocation_state_with_kwargs_deprecation_warning(agent, mock_event_loop_cycle): + """Test that kwargs trigger deprecation warning and are merged correctly with invocation_state.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + # Should have nested structure when both invocation_state and kwargs are provided + assert invocation_state["invocation_state"] == {"my": "state"} + assert invocation_state["other_kwarg"] == "foobar" + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}, other_kwarg="foobar") + + # Verify deprecation warning was issued + assert len(captured_warnings) == 1 + assert issubclass(captured_warnings[0].category, UserWarning) + assert "`**kwargs` parameter is deprecating, use `invocation_state` instead." in str(captured_warnings[0].message) + + +def test_agent__call__invocation_state_only_no_warning(agent, mock_event_loop_cycle): + """Test that using only invocation_state does not trigger warning and passes state directly.""" + + async def check_invocation_state(**kwargs): + invocation_state = kwargs["invocation_state"] + + assert invocation_state["my"] == "state" + assert "agent" in invocation_state + yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}) + + mock_event_loop_cycle.side_effect = check_invocation_state + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agent("hello!", invocation_state={"my": "state"}) + + assert len(captured_warnings) == 0 From 45ef6ce510fd96d5d31602420b1e75793c0b65cf Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Wed, 8 Oct 2025 13:21:50 -0400 Subject: [PATCH 13/13] Apply suggestion from @Unshure Co-authored-by: Nick Clegg --- src/strands/agent/agent.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index ee21c5c84..8607a2601 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -583,11 +583,9 @@ async def stream_async( if invocation_state is not None: merged_state = invocation_state - # Get callback handler from merged state or use default - invocation_state_dict = merged_state.get("invocation_state", {}) - callback_handler = invocation_state_dict.get( - "callback_handler", merged_state.get("callback_handler", self.callback_handler) - ) + callback_handler = self.callback_handler + if kwargs: + callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) messages = self._convert_prompt_to_messages(prompt)