Skip to content

Commit e692133

Browse files
authored
bidi - fix mypy errors (#1308)
1 parent 6543097 commit e692133

File tree

4 files changed

+39
-33
lines changed

4 files changed

+39
-33
lines changed

src/strands/tools/_caller.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,10 @@ async def acall() -> ToolResult:
106106

107107
tool_result = run_async(acall)
108108

109-
# Apply conversation management if agent supports it (traditional agents)
110-
if hasattr(self._agent, "conversation_manager"):
109+
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
110+
from ..agent import Agent
111+
112+
if isinstance(self._agent, Agent):
111113
self._agent.conversation_manager.apply_management(self._agent)
112114

113115
return tool_result

src/strands/tools/executors/_executor.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,19 @@ async def _invoke_before_tool_call_hook(
4949
invocation_state: dict[str, Any],
5050
) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]:
5151
"""Invoke the appropriate before tool call hook based on agent type."""
52-
event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent
53-
return await agent.hooks.invoke_callbacks_async(
54-
event_cls(
55-
agent=agent,
56-
selected_tool=tool_func,
57-
tool_use=tool_use,
58-
invocation_state=invocation_state,
59-
)
52+
kwargs = {
53+
"selected_tool": tool_func,
54+
"tool_use": tool_use,
55+
"invocation_state": invocation_state,
56+
}
57+
event = (
58+
BeforeToolCallEvent(agent=cast("Agent", agent), **kwargs)
59+
if ToolExecutor._is_agent(agent)
60+
else BidiBeforeToolCallEvent(agent=cast("BidiAgent", agent), **kwargs)
6061
)
6162

63+
return await agent.hooks.invoke_callbacks_async(event)
64+
6265
@staticmethod
6366
async def _invoke_after_tool_call_hook(
6467
agent: "Agent | BidiAgent",
@@ -70,19 +73,22 @@ async def _invoke_after_tool_call_hook(
7073
cancel_message: str | None = None,
7174
) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]:
7275
"""Invoke the appropriate after tool call hook based on agent type."""
73-
event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent
74-
return await agent.hooks.invoke_callbacks_async(
75-
event_cls(
76-
agent=agent,
77-
selected_tool=selected_tool,
78-
tool_use=tool_use,
79-
invocation_state=invocation_state,
80-
result=result,
81-
exception=exception,
82-
cancel_message=cancel_message,
83-
)
76+
kwargs = {
77+
"selected_tool": selected_tool,
78+
"tool_use": tool_use,
79+
"invocation_state": invocation_state,
80+
"result": result,
81+
"exception": exception,
82+
"cancel_message": cancel_message,
83+
}
84+
event = (
85+
AfterToolCallEvent(agent=cast("Agent", agent), **kwargs)
86+
if ToolExecutor._is_agent(agent)
87+
else BidiAfterToolCallEvent(agent=cast("BidiAgent", agent), **kwargs)
8488
)
8589

90+
return await agent.hooks.invoke_callbacks_async(event)
91+
8692
@staticmethod
8793
async def _stream(
8894
agent: "Agent | BidiAgent",
@@ -247,7 +253,7 @@ async def _stream(
247253

248254
@staticmethod
249255
async def _stream_with_trace(
250-
agent: "Agent | BidiAgent",
256+
agent: "Agent",
251257
tool_use: ToolUse,
252258
tool_results: list[ToolResult],
253259
cycle_trace: Trace,
@@ -259,7 +265,7 @@ async def _stream_with_trace(
259265
"""Execute tool with tracing and metrics collection.
260266
261267
Args:
262-
agent: The agent (Agent or BidiAgent) for which the tool is being executed.
268+
agent: The agent for which the tool is being executed.
263269
tool_use: Metadata and inputs for the tool to be executed.
264270
tool_results: List of tool results from each tool execution.
265271
cycle_trace: Trace object for the current event loop cycle.
@@ -308,7 +314,7 @@ async def _stream_with_trace(
308314
# pragma: no cover
309315
def _execute(
310316
self,
311-
agent: "Agent | BidiAgent",
317+
agent: "Agent",
312318
tool_uses: list[ToolUse],
313319
tool_results: list[ToolResult],
314320
cycle_trace: Trace,
@@ -319,7 +325,7 @@ def _execute(
319325
"""Execute the given tools according to this executor's strategy.
320326
321327
Args:
322-
agent: The agent (Agent or BidiAgent) for which tools are being executed.
328+
agent: The agent for which tools are being executed.
323329
tool_uses: Metadata and inputs for the tools to be executed.
324330
tool_results: List of tool results from each tool execution.
325331
cycle_trace: Trace object for the current event loop cycle.

src/strands/tools/executors/concurrent.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
if TYPE_CHECKING: # pragma: no cover
1414
from ...agent import Agent
15-
from ...experimental.bidi import BidiAgent
1615
from ..structured_output._structured_output_context import StructuredOutputContext
1716

1817

@@ -22,7 +21,7 @@ class ConcurrentToolExecutor(ToolExecutor):
2221
@override
2322
async def _execute(
2423
self,
25-
agent: "Agent | BidiAgent",
24+
agent: "Agent",
2625
tool_uses: list[ToolUse],
2726
tool_results: list[ToolResult],
2827
cycle_trace: Trace,
@@ -33,7 +32,7 @@ async def _execute(
3332
"""Execute tools concurrently.
3433
3534
Args:
36-
agent: The agent (Agent or BidiAgent) for which tools are being executed.
35+
agent: The agent for which tools are being executed.
3736
tool_uses: Metadata and inputs for the tools to be executed.
3837
tool_results: List of tool results from each tool execution.
3938
cycle_trace: Trace object for the current event loop cycle.
@@ -79,7 +78,7 @@ async def _execute(
7978

8079
async def _task(
8180
self,
82-
agent: "Agent | BidiAgent",
81+
agent: "Agent",
8382
tool_use: ToolUse,
8483
tool_results: list[ToolResult],
8584
cycle_trace: Trace,
@@ -94,7 +93,7 @@ async def _task(
9493
"""Execute a single tool and put results in the task queue.
9594
9695
Args:
97-
agent: The agent (Agent or BidiAgent) executing the tool.
96+
agent: The agent executing the tool.
9897
tool_use: Tool use metadata and inputs.
9998
tool_results: List of tool results from each tool execution.
10099
cycle_trace: Trace object for the current event loop cycle.

src/strands/tools/executors/sequential.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
if TYPE_CHECKING: # pragma: no cover
1313
from ...agent import Agent
14-
from ...experimental.bidi import BidiAgent
1514
from ..structured_output._structured_output_context import StructuredOutputContext
1615

1716

@@ -21,7 +20,7 @@ class SequentialToolExecutor(ToolExecutor):
2120
@override
2221
async def _execute(
2322
self,
24-
agent: "Agent | BidiAgent",
23+
agent: "Agent",
2524
tool_uses: list[ToolUse],
2625
tool_results: list[ToolResult],
2726
cycle_trace: Trace,
@@ -34,7 +33,7 @@ async def _execute(
3433
Breaks early if an interrupt is raised by the user.
3534
3635
Args:
37-
agent: The agent (Agent or BidiAgent) for which tools are being executed.
36+
agent: The agent for which tools are being executed.
3837
tool_uses: Metadata and inputs for the tools to be executed.
3938
tool_results: List of tool results from each tool execution.
4039
cycle_trace: Trace object for the current event loop cycle.

0 commit comments

Comments
 (0)