Skip to content

Commit a135262

Browse files
committed
fix: adjust **kwargs in multiagent primitives
1 parent 45ef6ce commit a135262

File tree

6 files changed

+25
-18
lines changed

6 files changed

+25
-18
lines changed

src/strands/multiagent/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import asyncio
7+
import warnings
78
from abc import ABC, abstractmethod
89
from concurrent.futures import ThreadPoolExecutor
910
from dataclasses import dataclass, field
@@ -85,15 +86,14 @@ class MultiAgentBase(ABC):
8586

8687
@abstractmethod
8788
async def invoke_async(
88-
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
89+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None
8990
) -> MultiAgentResult:
9091
"""Invoke asynchronously.
9192
9293
Args:
9394
task: The task to execute
9495
invocation_state: Additional state/context passed to underlying agents.
9596
Defaults to None to avoid mutable default argument issues.
96-
**kwargs: Additional keyword arguments passed to underlying agents.
9797
"""
9898
raise NotImplementedError("invoke_async not implemented")
9999

@@ -111,8 +111,12 @@ def __call__(
111111
if invocation_state is None:
112112
invocation_state = {}
113113

114+
if kwargs:
115+
invocation_state.update(kwargs)
116+
warnings.warn("`**kwargs` parameter is deprecating, use `invocation_state` instead.", stacklevel=2)
117+
114118
def execute() -> MultiAgentResult:
115-
return asyncio.run(self.invoke_async(task, invocation_state, **kwargs))
119+
return asyncio.run(self.invoke_async(task, invocation_state))
116120

117121
with ThreadPoolExecutor() as executor:
118122
future = executor.submit(execute)

src/strands/multiagent/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,11 +572,11 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
572572
elif isinstance(node.executor, Agent):
573573
if self.node_timeout is not None:
574574
agent_response = await asyncio.wait_for(
575-
node.executor.invoke_async(node_input, **invocation_state),
575+
node.executor.invoke_async(node_input, invocation_state=invocation_state),
576576
timeout=self.node_timeout,
577577
)
578578
else:
579-
agent_response = await node.executor.invoke_async(node_input, **invocation_state)
579+
agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state)
580580

581581
# Extract metrics from agent response
582582
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)

src/strands/multiagent/swarm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,8 +635,7 @@ async def _execute_node(
635635
# Execute node
636636
result = None
637637
node.reset_executor_state()
638-
# Unpacking since this is the agent class. Other executors should not unpack
639-
result = await node.executor.invoke_async(node_input, **invocation_state)
638+
result = await node.executor.invoke_async(node_input, invocation_state=invocation_state)
640639

641640
execution_time = round((time.time() - start_time) * 1000)
642641

tests/strands/multiagent/test_base.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,23 +153,23 @@ class TestMultiAgent(MultiAgentBase):
153153
def __init__(self):
154154
self.invoke_async_called = False
155155
self.received_task = None
156-
self.received_kwargs = None
156+
self.received_invocation_state = None
157157

158-
async def invoke_async(self, task, invocation_state, **kwargs):
158+
async def invoke_async(self, task, invocation_state=None):
159159
self.invoke_async_called = True
160160
self.received_task = task
161-
self.received_kwargs = kwargs
161+
self.received_invocation_state = invocation_state
162162
return MultiAgentResult(
163163
status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)}
164164
)
165165

166166
agent = TestMultiAgent()
167167

168168
# Test with string task
169-
result = agent("test task", param1="value1", param2="value2")
169+
result = agent("test task", param1="value1", param2="value2", invocation_state={"value3": "value4"})
170170

171171
assert agent.invoke_async_called
172172
assert agent.received_task == "test task"
173-
assert agent.received_kwargs == {"param1": "value1", "param2": "value2"}
173+
assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"}
174174
assert isinstance(result, MultiAgentResult)
175175
assert result.status == Status.COMPLETED

tests/strands/multiagent/test_graph.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span):
310310
result = await graph.invoke_async([{"text": "Original task"}])
311311

312312
# Verify entry node was called with original task
313-
entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}])
313+
entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}], invocation_state={})
314314
assert result.status == Status.COMPLETED
315315
mock_strands_tracer.start_multiagent_span.assert_called()
316316
mock_use_span.assert_called_once()
@@ -906,7 +906,7 @@ def __init__(self, name):
906906
self._session_manager = None
907907
self.hooks = HookRegistry()
908908

909-
async def invoke_async(self, input_data):
909+
async def invoke_async(self, input_data, invocation_state=None):
910910
# Increment execution count in state
911911
count = self.state.get("execution_count") or 0
912912
self.state.set("execution_count", count + 1)
@@ -1300,7 +1300,9 @@ async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span):
13001300
test_invocation_state = {"custom_param": "test_value", "another_param": 42}
13011301
result = await graph.invoke_async("Test kwargs passing", test_invocation_state)
13021302

1303-
kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state)
1303+
kwargs_agent.invoke_async.assert_called_once_with(
1304+
[{"text": "Test kwargs passing"}], invocation_state=test_invocation_state
1305+
)
13041306
assert result.status == Status.COMPLETED
13051307

13061308

@@ -1335,5 +1337,7 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
13351337
test_invocation_state = {"custom_param": "test_value", "another_param": 42}
13361338
result = graph("Test kwargs passing sync", test_invocation_state)
13371339

1338-
kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state)
1340+
kwargs_agent.invoke_async.assert_called_once_with(
1341+
[{"text": "Test kwargs passing sync"}], invocation_state=test_invocation_state
1342+
)
13391343
assert result.status == Status.COMPLETED

tests/strands/multiagent/test_swarm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span):
558558
test_kwargs = {"custom_param": "test_value", "another_param": 42}
559559
result = await swarm.invoke_async("Test kwargs passing", test_kwargs)
560560

561-
assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
561+
assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs}
562562
assert result.status == Status.COMPLETED
563563

564564

@@ -572,5 +572,5 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
572572
test_kwargs = {"custom_param": "test_value", "another_param": 42}
573573
result = swarm("Test kwargs passing sync", test_kwargs)
574574

575-
assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
575+
assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs}
576576
assert result.status == Status.COMPLETED

0 commit comments

Comments
 (0)