Skip to content

Commit a72cee2

Browse files
zastrowmWorkshop Participant
authored andcommitted
chore: Remove agent.tool_config and update usages to use tool_specs (#388)
Agent.tool_config is a configuration object which serves as a wrapper to tool_specs and nothing else. We actually don't use the toolChoice at all anywhere, and the `Tool` wrapper also was a container that served no purpose as everywhere we used the tools, we wanted the ToolSpec anyways. Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 91ef5cd commit a72cee2

File tree

8 files changed

+44
-47
lines changed

8 files changed

+44
-47
lines changed

src/strands/agent/agent.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ..types.content import ContentBlock, Message, Messages
3232
from ..types.exceptions import ContextWindowOverflowException
3333
from ..types.models import Model
34-
from ..types.tools import ToolConfig, ToolResult, ToolUse
34+
from ..types.tools import ToolResult, ToolUse
3535
from ..types.traces import AttributeValue
3636
from .agent_result import AgentResult
3737
from .conversation_manager import (
@@ -335,15 +335,6 @@ def tool_names(self) -> list[str]:
335335
all_tools = self.tool_registry.get_all_tools_config()
336336
return list(all_tools.keys())
337337

338-
@property
339-
def tool_config(self) -> ToolConfig:
340-
"""Get the tool configuration for this agent.
341-
342-
Returns:
343-
The complete tool configuration.
344-
"""
345-
return self.tool_registry.initialize_tool_config()
346-
347338
def __del__(self) -> None:
348339
"""Clean up resources when Agent is garbage collected.
349340

src/strands/event_loop/event_loop.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212
import time
1313
import uuid
14-
from typing import TYPE_CHECKING, Any, AsyncGenerator
14+
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
1515

1616
from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
1717
from ..experimental.hooks.registry import get_registry
@@ -21,7 +21,7 @@
2121
from ..types.content import Message
2222
from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
2323
from ..types.streaming import Metrics, StopReason
24-
from ..types.tools import ToolGenerator, ToolResult, ToolUse
24+
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
2525
from .message_processor import clean_orphaned_empty_tool_uses
2626
from .streaming import stream_messages
2727

@@ -112,10 +112,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
112112
model_id=model_id,
113113
)
114114

115+
tool_specs = agent.tool_registry.get_all_tool_specs()
116+
115117
try:
116118
# TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding
117119
# to the callback handler. This will be revisited when migrating to strongly typed events.
118-
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, agent.tool_config):
120+
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
119121
if "callback" in event:
120122
yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}}
121123

@@ -172,12 +174,6 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
172174

173175
# If the model is requesting to use tools
174176
if stop_reason == "tool_use":
175-
if agent.tool_config is None:
176-
raise EventLoopException(
177-
Exception("Model requested tool use but no tool config provided"),
178-
kwargs["request_state"],
179-
)
180-
181177
# Handle tool execution
182178
events = _handle_tool_execution(
183179
stop_reason,
@@ -285,7 +281,10 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
285281
"model": agent.model,
286282
"system_prompt": agent.system_prompt,
287283
"messages": agent.messages,
288-
"tool_config": agent.tool_config,
284+
"tool_config": ToolConfig( # for backwards compatability
285+
tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()],
286+
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
287+
),
289288
}
290289
)
291290

src/strands/event_loop/streaming.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
StreamEvent,
2020
Usage,
2121
)
22-
from ..types.tools import ToolConfig, ToolUse
22+
from ..types.tools import ToolSpec, ToolUse
2323

2424
logger = logging.getLogger(__name__)
2525

@@ -304,24 +304,23 @@ async def stream_messages(
304304
model: Model,
305305
system_prompt: Optional[str],
306306
messages: Messages,
307-
tool_config: Optional[ToolConfig],
307+
tool_specs: list[ToolSpec],
308308
) -> AsyncGenerator[dict[str, Any], None]:
309309
"""Streams messages to the model and processes the response.
310310
311311
Args:
312312
model: Model provider.
313313
system_prompt: The system prompt to send.
314314
messages: List of messages to send.
315-
tool_config: Configuration for the tools to use.
315+
tool_specs: The list of tool specs.
316316
317317
Returns:
318318
The reason for stopping, the final message, and the usage metrics
319319
"""
320320
logger.debug("model=<%s> | streaming messages", model)
321321

322322
messages = remove_blank_messages_content_text(messages)
323-
tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None
324323

325-
chunks = model.converse(messages, tool_specs, system_prompt)
324+
chunks = model.converse(messages, tool_specs if tool_specs else None, system_prompt)
326325
async for event in process_stream(chunks, messages):
327326
yield event

src/strands/tools/registry.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from strands.tools.decorator import DecoratedFunctionTool
1919

20-
from ..types.tools import AgentTool, Tool, ToolChoice, ToolChoiceAuto, ToolConfig, ToolSpec
20+
from ..types.tools import AgentTool, ToolSpec
2121
from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec
2222

2323
logger = logging.getLogger(__name__)
@@ -472,20 +472,15 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None:
472472
for tool_name, error in tool_import_errors.items():
473473
logger.debug("tool_name=<%s> | import error | %s", tool_name, error)
474474

475-
def initialize_tool_config(self) -> ToolConfig:
476-
"""Initialize tool configuration from tool handler with optional filtering.
475+
def get_all_tool_specs(self) -> list[ToolSpec]:
476+
"""Get all the tool specs for all tools in this registry..
477477
478478
Returns:
479-
Tool config.
479+
A list of ToolSpecs.
480480
"""
481481
all_tools = self.get_all_tools_config()
482-
483-
tools: List[Tool] = [{"toolSpec": tool_spec} for tool_spec in all_tools.values()]
484-
485-
return ToolConfig(
486-
tools=tools,
487-
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
488-
)
482+
tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()]
483+
return tools
489484

490485
def validate_tool_spec(self, tool_spec: ToolSpec) -> None:
491486
"""Validate tool specification against required schema.

tests/strands/agent/test_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_impor
180180

181181
agent = Agent(tools=[tool_decorated, tool_module, tool_imported])
182182

183-
tru_tool_names = sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"])
183+
tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs())
184184
exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]
185185

186186
assert tru_tool_names == exp_tool_names
@@ -191,7 +191,7 @@ def test_agent__init__tool_loader_dict(tool_module, tool_registry):
191191

192192
agent = Agent(tools=[{"name": "tool_module", "path": tool_module}])
193193

194-
tru_tool_names = sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"])
194+
tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs())
195195
exp_tool_names = ["tool_module"]
196196

197197
assert tru_tool_names == exp_tool_names

tests/strands/event_loop/test_event_loop.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ def messages():
3535
return [{"role": "user", "content": [{"text": "Hello"}]}]
3636

3737

38-
@pytest.fixture
39-
def tool_config():
40-
return {"tools": [{"toolSpec": {"name": "tool_for_testing"}}], "toolChoice": {"auto": {}}}
41-
42-
4338
@pytest.fixture
4439
def tool_registry():
4540
return ToolRegistry()
@@ -116,13 +111,12 @@ def hook_provider(hook_registry):
116111

117112

118113
@pytest.fixture
119-
def agent(model, system_prompt, messages, tool_config, tool_registry, thread_pool, hook_registry):
114+
def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry):
120115
mock = unittest.mock.Mock(name="agent")
121116
mock.config.cache_points = []
122117
mock.model = model
123118
mock.system_prompt = system_prompt
124119
mock.messages = messages
125-
mock.tool_config = tool_config
126120
mock.tool_registry = tool_registry
127121
mock.thread_pool = thread_pool
128122
mock.event_loop_metrics = EventLoopMetrics()
@@ -298,6 +292,7 @@ async def test_event_loop_cycle_tool_result(
298292
system_prompt,
299293
messages,
300294
tool_stream,
295+
tool_registry,
301296
agenerator,
302297
alist,
303298
):
@@ -353,7 +348,7 @@ async def test_event_loop_cycle_tool_result(
353348
},
354349
{"role": "assistant", "content": [{"text": "test text"}]},
355350
],
356-
[{"name": "tool_for_testing"}],
351+
tool_registry.get_all_tool_specs(),
357352
"p1",
358353
)
359354

tests/strands/event_loop/test_streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ async def test_stream_messages(agenerator, alist):
589589
mock_model,
590590
system_prompt="test prompt",
591591
messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}],
592-
tool_config=None,
592+
tool_specs=None,
593593
)
594594

595595
tru_events = await alist(stream)

tests/strands/tools/test_registry.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88

9+
import strands
910
from strands.tools import PythonAgentTool
1011
from strands.tools.decorator import DecoratedFunctionTool, tool
1112
from strands.tools.registry import ToolRegistry
@@ -46,6 +47,23 @@ def test_register_tool_with_similar_name_raises():
4647
)
4748

4849

50+
def test_get_all_tool_specs_returns_right_tool_specs():
51+
tool_1 = strands.tool(lambda a: a, name="tool_1")
52+
tool_2 = strands.tool(lambda b: b, name="tool_2")
53+
54+
tool_registry = ToolRegistry()
55+
56+
tool_registry.register_tool(tool_1)
57+
tool_registry.register_tool(tool_2)
58+
59+
tool_specs = tool_registry.get_all_tool_specs()
60+
61+
assert tool_specs == [
62+
tool_1.tool_spec,
63+
tool_2.tool_spec,
64+
]
65+
66+
4967
def test_scan_module_for_tools():
5068
@tool
5169
def tool_function_1(a):

0 commit comments

Comments
 (0)