From 9602938aa0b7dcc996cc1dc432c8322a9dc76659 Mon Sep 17 00:00:00 2001 From: Liat Iusim Date: Wed, 19 Nov 2025 16:02:30 +0200 Subject: [PATCH] fix(concurrent-invocations): added protection from concurrent invocations to the same agent instance --- src/strands/agent/agent.py | 18 ++++++--- tests/strands/agent/test_agent.py | 10 +++++ tests_integ/test_stream_agent.py | 66 ++++++++++++++++++++++++++++++- 3 files changed, 87 insertions(+), 7 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e13b9f6d8..26953dd47 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -9,6 +9,7 @@ 2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ +import asyncio import json import logging import random @@ -293,6 +294,7 @@ def __init__( self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME self.description = description + self._invocation_lock = asyncio.Lock() # If not provided, create a new PrintingCallbackHandler instance # If explicitly set to None, use null_callback_handler @@ -494,13 +496,17 @@ async def invoke_async( - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - events = self.stream_async( - prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs - ) - async for event in events: - _ = event + if self._invocation_lock.locked(): + raise RuntimeError("Agent is already processing a request. Concurrent invocations are not supported.") + + async with self._invocation_lock: + events = self.stream_async( + prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs + ) + async for event in events: + _ = event - return cast(AgentResult, event["result"]) + return cast(AgentResult, event["result"]) def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index d04f57948..e1f5fbc75 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -783,6 +783,16 @@ async def test_agent__call__in_async_context(mock_model, agent, agenerator): assert tru_message == exp_message +@pytest.mark.asyncio +async def test_agent_parallel_invocations(): + model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) + agent = Agent(model=model) + + async with agent._invocation_lock: + with pytest.raises(RuntimeError, match="Concurrent invocations are not supported"): + await agent.invoke_async("test") + + @pytest.mark.asyncio async def test_agent_invoke_async(mock_model, agent, agenerator): mock_model.mock_stream.return_value = agenerator( diff --git a/tests_integ/test_stream_agent.py b/tests_integ/test_stream_agent.py index 01f203390..20a057c52 100644 --- a/tests_integ/test_stream_agent.py +++ b/tests_integ/test_stream_agent.py @@ -4,13 +4,22 @@ """ import logging +import threading +import time -from strands import Agent +from strands import Agent, tool logging.getLogger("strands").setLevel(logging.DEBUG) logging.basicConfig(format="%(levelname)s | %(name)s | %(message)s", handlers=[logging.StreamHandler()]) +@tool +def wait(seconds: int) -> None: + """Waits x seconds based on the user input. + Seconds - seconds to wait""" + time.sleep(seconds) + + class ToolCountingCallbackHandler: def __init__(self): self.tool_count = 0 @@ -68,3 +77,58 @@ def test_basic_interaction(): agent("Tell me a short joke from your general knowledge") print("\nBasic Interaction Complete") + + +def test_parallel_async_interaction(): + """Test that concurrent agent invocations are not allowed""" + + # Initialize agent + agent = Agent( + callback_handler=ToolCountingCallbackHandler().callback_handler, load_tools_from_directory=False, tools=[wait] + ) + + # Track results from both threads + results = {"thread1": None, "thread2": None, "exception": None} + + def invoke_agent_1(): + """First invocation - should succeed""" + try: + result = agent("wait 5 seconds") + results["thread1"] = result + except Exception as e: + results["thread1"] = e + + def invoke_agent_2(): + """Second invocation - should fail with exception""" + try: + result = agent("wait 5 seconds") + results["thread2"] = result + except Exception as e: + results["thread2"] = e + results["exception"] = e + + # Start first invocation + thread1 = threading.Thread(target=invoke_agent_1) + thread1.start() + + # Give it time to start and begin waiting + time.sleep(1) + + # Try second invocation while first is still running + thread2 = threading.Thread(target=invoke_agent_2) + thread2.start() + + thread1.join() + thread2.join() + + # Assertions + assert results["thread1"] is not None, "First invocation should complete" + assert not isinstance(results["thread1"], Exception), "First invocation should succeed" + + assert results["exception"] is not None, "Second invocation should throw exception" + assert isinstance(results["thread2"], Exception), "Second invocation should fail" + + expected_message = "Agent is already processing a request. Concurrent invocations are not supported" + assert expected_message in str(results["thread2"]), ( + f"Exception message should contain '{expected_message}', but got: {str(results['thread2'])}" + )