Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import json
import logging
import random
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Expand Down Expand Up @@ -374,7 +375,9 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())

def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
def __call__(
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface with multiple input patterns:
Expand All @@ -389,7 +392,8 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
- list[ContentBlock]: Multi-modal content blocks
- list[Message]: Complete messages with roles
- None: Use existing conversation history
**kwargs: Additional parameters to pass through the event loop.
invocation_state: Additional parameters to pass through the event loop.
**kwargs: Additional parameters to pass through the event loop.[Deprecating]

Returns:
Result object containing:
Expand All @@ -401,13 +405,15 @@ def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
"""

def execute() -> AgentResult:
return asyncio.run(self.invoke_async(prompt, **kwargs))
return asyncio.run(self.invoke_async(prompt, invocation_state=invocation_state, **kwargs))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()

async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
async def invoke_async(
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.

This method implements the conversational interface with multiple input patterns:
Expand All @@ -422,7 +428,8 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR
- list[ContentBlock]: Multi-modal content blocks
- list[Message]: Complete messages with roles
- None: Use existing conversation history
**kwargs: Additional parameters to pass through the event loop.
invocation_state: Additional parameters to pass through the event loop.
**kwargs: Additional parameters to pass through the event loop.[Deprecating]

Returns:
Result: object containing:
Expand All @@ -432,7 +439,7 @@ async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentR
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
events = self.stream_async(prompt, **kwargs)
events = self.stream_async(prompt, invocation_state=invocation_state, **kwargs)
async for event in events:
_ = event

Expand Down Expand Up @@ -528,9 +535,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))

async def stream_async(
self,
prompt: AgentInput = None,
**kwargs: Any,
self, prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.

Expand All @@ -546,7 +551,8 @@ async def stream_async(
- list[ContentBlock]: Multi-modal content blocks
- list[Message]: Complete messages with roles
- None: Use existing conversation history
**kwargs: Additional parameters to pass to the event loop.
invocation_state: Additional parameters to pass through the event loop.
**kwargs: Additional parameters to pass to the event loop.[Deprecating]

Yields:
An async iterator that yields events. Each event is a dictionary containing
Expand All @@ -567,7 +573,19 @@ async def stream_async(
yield event["data"]
```
"""
callback_handler = kwargs.get("callback_handler", self.callback_handler)
merged_state = {}
if kwargs:
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
merged_state.update(kwargs)
if invocation_state is not None:
merged_state["invocation_state"] = invocation_state
else:
if invocation_state is not None:
merged_state = invocation_state

callback_handler = self.callback_handler
if kwargs:
callback_handler = kwargs.get("callback_handler", self.callback_handler)

# Process input and get message to add (if any)
messages = self._convert_prompt_to_messages(prompt)
Expand All @@ -576,10 +594,10 @@ async def stream_async(

with trace_api.use_span(self.trace_span):
try:
events = self._run_loop(messages, invocation_state=kwargs)
events = self._run_loop(messages, invocation_state=merged_state)

async for event in events:
event.prepare(invocation_state=kwargs)
event.prepare(invocation_state=merged_state)

if event.is_callback_event:
as_dict = event.as_dict()
Expand Down
56 changes: 56 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import textwrap
import unittest.mock
import warnings
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -1877,3 +1878,58 @@ def test_tool(action: str) -> str:
assert '"action": "test_value"' in tool_call_text
assert '"agent"' not in tool_call_text
assert '"extra_param"' not in tool_call_text


def test_agent__call__handles_none_invocation_state(mock_model, agent):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add one more unit test that can see what invocation state is passed to the event loop?

I want an invocation like this:

agent("hello!", invocation_state={"my": "state"}, other_kwarg="foobar")

To result in an invocation state like this:

event_loop_cycle.called_with(mock_agent, invocation_state=
{
  "invocation_state": {
    "my": "state"
  },
  "other_kwarg": "foobar"
})

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And similarly,

agent("hello!", invocation_state={"my": "state"})

called with:

event_loop_cycle.called_with(mock_agent, invocation_state={"my": "state"})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ Can you add the test case Mackenzie called out as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, see this test

"""Test that agent handles None invocation_state without AttributeError."""
mock_model.mock_stream.return_value = [
{"contentBlockDelta": {"delta": {"text": "test response"}}},
{"contentBlockStop": {}},
]

# This should not raise AttributeError: 'NoneType' object has no attribute 'get'
result = agent("test", invocation_state=None)

assert result.message["content"][0]["text"] == "test response"
assert result.stop_reason == "end_turn"


def test_agent__call__invocation_state_with_kwargs_deprecation_warning(agent, mock_event_loop_cycle):
"""Test that kwargs trigger deprecation warning and are merged correctly with invocation_state."""

async def check_invocation_state(**kwargs):
invocation_state = kwargs["invocation_state"]
# Should have nested structure when both invocation_state and kwargs are provided
assert invocation_state["invocation_state"] == {"my": "state"}
assert invocation_state["other_kwarg"] == "foobar"
yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})

mock_event_loop_cycle.side_effect = check_invocation_state

with warnings.catch_warnings(record=True) as captured_warnings:
warnings.simplefilter("always")
agent("hello!", invocation_state={"my": "state"}, other_kwarg="foobar")

# Verify deprecation warning was issued
assert len(captured_warnings) == 1
assert issubclass(captured_warnings[0].category, UserWarning)
assert "`**kwargs` parameter is deprecating, use `invocation_state` instead." in str(captured_warnings[0].message)


def test_agent__call__invocation_state_only_no_warning(agent, mock_event_loop_cycle):
"""Test that using only invocation_state does not trigger warning and passes state directly."""

async def check_invocation_state(**kwargs):
invocation_state = kwargs["invocation_state"]

assert invocation_state["my"] == "state"
assert "agent" in invocation_state
yield EventLoopStopEvent("stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {})

mock_event_loop_cycle.side_effect = check_invocation_state

with warnings.catch_warnings(record=True) as captured_warnings:
warnings.simplefilter("always")
agent("hello!", invocation_state={"my": "state"})

assert len(captured_warnings) == 0