diff --git a/.github/workflows/lint-toolbox-adk.yaml b/.github/workflows/lint-toolbox-adk.yaml index 6bb8e3bc..4e41544e 100644 --- a/.github/workflows/lint-toolbox-adk.yaml +++ b/.github/workflows/lint-toolbox-adk.yaml @@ -67,12 +67,12 @@ jobs: with: python-version: "3.13" + - name: Install test requirements + run: pip install -e .[test] + - name: Install library requirements run: pip install -r requirements.txt - - name: Install test requirements - run: pip install .[test] - - name: Run linters run: | black --check . diff --git a/.github/workflows/lint-toolbox-core.yaml b/.github/workflows/lint-toolbox-core.yaml index 05994176..100f2e8e 100644 --- a/.github/workflows/lint-toolbox-core.yaml +++ b/.github/workflows/lint-toolbox-core.yaml @@ -67,12 +67,12 @@ jobs: with: python-version: "3.13" + - name: Install test requirements + run: pip install -e .[test] + - name: Install library requirements run: pip install -r requirements.txt - - name: Install test requirements - run: pip install .[test] - - name: Run linters run: | black --check . diff --git a/.github/workflows/lint-toolbox-langchain.yaml b/.github/workflows/lint-toolbox-langchain.yaml index 79d46ada..dae263fa 100644 --- a/.github/workflows/lint-toolbox-langchain.yaml +++ b/.github/workflows/lint-toolbox-langchain.yaml @@ -67,12 +67,12 @@ jobs: with: python-version: "3.13" + - name: Install test requirements + run: pip install -e .[test] + - name: Install library requirements run: pip install -r requirements.txt - - name: Install test requirements - run: pip install .[test] - - name: Run linters run: | black --check . diff --git a/.github/workflows/lint-toolbox-llamaindex.yaml b/.github/workflows/lint-toolbox-llamaindex.yaml index c6f12ea5..7523413a 100644 --- a/.github/workflows/lint-toolbox-llamaindex.yaml +++ b/.github/workflows/lint-toolbox-llamaindex.yaml @@ -67,12 +67,12 @@ jobs: with: python-version: "3.13" + - name: Install test requirements + run: pip install -e .[test] + - name: Install library requirements run: pip install -r requirements.txt - - name: Install test requirements - run: pip install .[test] - - name: Run linters run: | black --check . diff --git a/packages/toolbox-adk/README.md b/packages/toolbox-adk/README.md index bc9b3ea7..581814f1 100644 --- a/packages/toolbox-adk/README.md +++ b/packages/toolbox-adk/README.md @@ -210,6 +210,9 @@ creds = CredentialStrategy.from_adk_credentials(auth_credential, scheme) Some tools may define their own authentication requirements (e.g., Salesforce OAuth, GitHub PAT) via `authSources` in their schema. You can provide a mapping of getters to resolve these tokens at runtime. +> [!TIP] +> Getters can optionally accept the ADK `ToolContext` as a single argument. This enables seamless integration of dynamic, end-user tokens that are tied to the current agent execution state. + ```python async def get_salesforce_token(): # Fetch token from secret manager or reliable source @@ -218,8 +221,9 @@ async def get_salesforce_token(): toolset = ToolboxToolset( server_url="...", auth_token_getters={ - "salesforce-auth": get_salesforce_token, # Async callable - "github-pat": lambda: "my-pat-token" # Sync callable or static lambda + "salesforce-auth": get_salesforce_token, # Async callable + "github-pat": lambda: "my-pat-token", # Sync callable or static lambda + "oauth-user": lambda ctx: ctx.state.get("auth_token") # Dynamic context-aware callable } ) ``` diff --git a/packages/toolbox-adk/integration.cloudbuild.yaml b/packages/toolbox-adk/integration.cloudbuild.yaml index 441b53f1..5b7cddbb 100644 --- a/packages/toolbox-adk/integration.cloudbuild.yaml +++ b/packages/toolbox-adk/integration.cloudbuild.yaml @@ -28,7 +28,8 @@ steps: # Use $$ to escape shell variable for Cloud Build CORE_VERSION=$$(python -c "v={}; exec(open('../toolbox-core/src/toolbox_core/version.py').read(), v); print(v['__version__'])") sed -i "s/toolbox-core==[0-9.]*/toolbox-core==$$CORE_VERSION/g" pyproject.toml - uv pip install -r requirements.txt -e '.[test]' + uv pip install -r requirements.txt + uv pip install -e '.[test]' entrypoint: /bin/bash - id: Run integration tests name: 'python:${_VERSION}' diff --git a/packages/toolbox-adk/src/toolbox_adk/tool.py b/packages/toolbox-adk/src/toolbox_adk/tool.py index 720407bf..59e813b2 100644 --- a/packages/toolbox-adk/src/toolbox_adk/tool.py +++ b/packages/toolbox-adk/src/toolbox_adk/tool.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging -from typing import Any, Awaitable, Callable, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Mapping, Optional import google.adk.auth.exchanger.oauth2_credential_exchanger as oauth2_credential_exchanger import google.adk.auth.oauth2_credential_util as oauth2_credential_util @@ -68,11 +69,13 @@ def __init__( self, core_tool: CoreToolboxTool, auth_config: Optional[CredentialConfig] = None, + adk_token_getters: Optional[Mapping[str, Any]] = None, ): """ Args: core_tool: The underlying toolbox_core.py tool instance. auth_config: Credential configuration to handle interactive flows. + adk_token_getters: Tool-specific auth token getters. """ # We act as a proxy. # We need to extract metadata from the core tool to satisfy BaseTool's contract. @@ -95,6 +98,7 @@ def __init__( ) self._core_tool = core_tool self._auth_config = auth_config + self._adk_token_getters = adk_token_getters or {} def _param_type_to_schema_type(self, param_type: str) -> Type: type_map = { @@ -260,12 +264,32 @@ async def run_async( "Falling back to request_credential.", exc_info=True, ) - # Fallback to request logic tool_context.request_credential(auth_config_adk) return { "error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in." } + if self._adk_token_getters: + # Pre-filter toolset getters to avoid unused-token errors from the core tool. + # This deferred loop also enables dynamic 1-arity `tool_context` injection. + needed_services = set() + for reqs in self._core_tool._required_authn_params.values(): + needed_services.update(reqs) + needed_services.update(self._core_tool._required_authz_tokens) + + for service, getter in self._adk_token_getters.items(): + if service in needed_services: + sig = inspect.signature(getter) + + if len(sig.parameters) == 1: + bound_getter = lambda t=getter, ctx=tool_context: t(ctx) + else: + bound_getter = getter + + self._core_tool = self._core_tool.add_auth_token_getter( + service, bound_getter + ) + result: Optional[Any] = None error: Optional[Exception] = None @@ -288,4 +312,5 @@ def bind_params(self, bounded_params: Dict[str, Any]) -> "ToolboxTool": return ToolboxTool( core_tool=new_core_tool, auth_config=self._auth_config, + adk_token_getters=self._adk_token_getters, ) diff --git a/packages/toolbox-adk/src/toolbox_adk/toolset.py b/packages/toolbox-adk/src/toolbox_adk/toolset.py index 8c7c340f..7689f96d 100644 --- a/packages/toolbox-adk/src/toolbox_adk/toolset.py +++ b/packages/toolbox-adk/src/toolbox_adk/toolset.py @@ -18,6 +18,7 @@ from google.adk.tools.base_tool import BaseTool from google.adk.tools.base_toolset import BaseToolset from google.adk.tools.tool_context import ToolContext +from toolbox_core.utils import validate_unused_requirements from typing_extensions import override from .client import ToolboxClient @@ -41,7 +42,15 @@ def __init__( ] = None, bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, auth_token_getters: Optional[ - Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]] + Mapping[ + str, + Union[ + Callable[[], str], + Callable[[], Awaitable[str]], + Callable[[ToolContext], str], + Callable[[ToolContext], Awaitable[str]], + ], + ] ] = None, **kwargs: Any, ): @@ -91,7 +100,6 @@ async def get_tools( core_tools = await self.client.load_toolset( self.__toolset_name, bound_params=self.__bound_params or {}, - auth_token_getters=self.__auth_token_getters or {}, ) tools.extend(core_tools) @@ -101,7 +109,6 @@ async def get_tools( core_tool = await self.client.load_tool( name, bound_params=self.__bound_params or {}, - auth_token_getters=self.__auth_token_getters or {}, ) tools.append(core_tool) @@ -110,15 +117,40 @@ async def get_tools( core_tools = await self.client.load_toolset( None, bound_params=self.__bound_params or {}, - auth_token_getters=self.__auth_token_getters or {}, ) tools.extend(core_tools) + # 4. Strictly validate unused toolset auth token getters using core logic + if self.__auth_token_getters: + overall_used_auth_keys = set() + for t in tools: + for reqs in t._required_authn_params.values(): + overall_used_auth_keys.update(reqs) + overall_used_auth_keys.update(t._required_authz_tokens) + + # Generate intuitive name for the error string if a specific toolset wasn't used + validation_name = self.__toolset_name + if not validation_name: + validation_name = ( + ", ".join(self.__tool_names) if self.__tool_names else "default" + ) + + validate_unused_requirements( + provided_auth_keys=set(self.__auth_token_getters.keys()), + provided_bound_keys=set(), + used_auth_keys=overall_used_auth_keys, + used_bound_keys=set(), + name=validation_name, + is_toolset=True, + target_type="list of tools" if not self.__toolset_name else None, + ) + # Wrap all core tools in ToolboxTool return [ ToolboxTool( core_tool=t, auth_config=self.client.credential_config, + adk_token_getters=self.__auth_token_getters, ) for t in tools ] diff --git a/packages/toolbox-adk/tests/integration/test_integration.py b/packages/toolbox-adk/tests/integration/test_integration.py index 524df61e..6fed882e 100644 --- a/packages/toolbox-adk/tests/integration/test_integration.py +++ b/packages/toolbox-adk/tests/integration/test_integration.py @@ -545,7 +545,48 @@ async def test_run_tool_unauth_with_auth(self, auth_token2: str): try: with pytest.raises( ValueError, - match=rf"Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth", + match=rf"Validation failed for list of tools 'get-row-by-id': unused auth tokens could not be applied to any tool: my-test-auth", + ): + await toolset.get_tools() + finally: + await toolset.close() + + async def test_run_multiple_tools_unauth_with_auth(self, auth_token2: str): + """Tests running multiple tools that don't require auth, verifying formatting of tool lists.""" + toolset = ToolboxToolset( + server_url="http://localhost:5000", + tool_names=["get-row-by-id", "search-rows"], + auth_token_getters={"my-test-auth": lambda: auth_token2}, + credentials=CredentialStrategy.toolbox_identity(), + ) + try: + with pytest.raises( + ValueError, + match=rf"Validation failed for list of tools 'get-row-by-id, search-rows': unused auth tokens could not be applied to any tool: my-test-auth", + ): + await toolset.get_tools() + finally: + await toolset.close() + + async def test_run_multiple_tools_partial_auth_usage(self, auth_token2: str): + """Tests that when some tokens are used and some aren't across diverse tools, only the truly unused tokens appear in the error.""" + toolset = ToolboxToolset( + server_url="http://localhost:5000", + tool_names=[ + "get-row-by-id-auth", + "search-rows", + ], # first requires 'my-test-auth', second requires nothing + auth_token_getters={ + "my-test-auth": lambda: auth_token2, + "extra-token": lambda: "fake", + }, + credentials=CredentialStrategy.toolbox_identity(), + ) + try: + with pytest.raises( + ValueError, + # 'my-test-auth' should be cleanly consumed and absent from the final error string. + match=r"Validation failed for list of tools 'get-row-by-id-auth, search-rows': unused auth tokens could not be applied to any tool: extra-token", ): await toolset.get_tools() finally: diff --git a/packages/toolbox-adk/tests/unit/test_tool.py b/packages/toolbox-adk/tests/unit/test_tool.py index 0f3c5120..2d0fb9d5 100644 --- a/packages/toolbox-adk/tests/unit/test_tool.py +++ b/packages/toolbox-adk/tests/unit/test_tool.py @@ -76,6 +76,48 @@ async def test_bind_params(self): assert new_tool._core_tool == new_core_mock mock_core.bind_params.assert_called_with({"a": 1}) + @pytest.mark.asyncio + async def test_dynamic_adk_token_getters(self): + core_tool = AsyncMock() + core_tool.__name__ = "mock" + core_tool.__doc__ = "mock doc" + core_tool._required_authn_params = {"param1": ["service1"]} + core_tool._required_authz_tokens = ["service2"] + core_tool.add_auth_token_getter = MagicMock(return_value=core_tool) + + def getter1(): + return "token1" + + def getter2(ctx): + return ctx.state.get("token2") + + adk_getters = { + "service1": getter1, + "service2": getter2, + } + + tool = ToolboxTool(core_tool, adk_token_getters=adk_getters) + + ctx = MagicMock() + ctx.state = {"token2": "dynamic_token2"} + + await tool.run_async({}, ctx) + + assert core_tool.add_auth_token_getter.call_count == 2 + + args1 = core_tool.add_auth_token_getter.call_args_list[0][0] + args2 = core_tool.add_auth_token_getter.call_args_list[1][0] + + # Because we iterate over items(), order might be dependent. + # Check that both services were processed and bound correctly + bound_getters = {args1[0]: args1[1], args2[0]: args2[1]} + + assert "service1" in bound_getters + assert bound_getters["service1"]() == "token1" + + assert "service2" in bound_getters + assert bound_getters["service2"]() == "dynamic_token2" + @pytest.mark.asyncio async def test_3lo_missing_client_secret(self): # Test ValueError when client_id/secret missing diff --git a/packages/toolbox-adk/tests/unit/test_toolset.py b/packages/toolbox-adk/tests/unit/test_toolset.py index 54962f5e..ff632e1e 100644 --- a/packages/toolbox-adk/tests/unit/test_toolset.py +++ b/packages/toolbox-adk/tests/unit/test_toolset.py @@ -49,12 +49,8 @@ async def test_get_tools_load_set_and_list(self, mock_client_cls): assert isinstance(tools[0], ToolboxTool) assert isinstance(tools[1], ToolboxTool) - mock_client.load_toolset.assert_awaited_with( - "set1", bound_params={"p": 1}, auth_token_getters={} - ) - mock_client.load_tool.assert_awaited_with( - "toolA", bound_params={"p": 1}, auth_token_getters={} - ) + mock_client.load_toolset.assert_awaited_with("set1", bound_params={"p": 1}) + mock_client.load_tool.assert_awaited_with("toolA", bound_params={"p": 1}) @patch("toolbox_adk.toolset.ToolboxClient") @pytest.mark.asyncio @@ -65,6 +61,8 @@ async def test_get_tools_with_auth_token_getters(self, mock_client_cls): t1 = MagicMock() t1.__name__ = "tool1" t1.__doc__ = "desc1" + t1._required_authn_params = {"param1": ["service"]} + t1._required_authz_tokens = [] mock_client.load_tool = AsyncMock(return_value=t1) auth_getters = {"service": lambda: "token"} @@ -75,10 +73,60 @@ async def test_get_tools_with_auth_token_getters(self, mock_client_cls): tools = await toolset.get_tools() assert len(tools) == 1 - mock_client.load_tool.assert_awaited_with( - "toolA", bound_params={}, auth_token_getters=auth_getters + mock_client.load_tool.assert_awaited_with("toolA", bound_params={}) + assert tools[0]._adk_token_getters == auth_getters + + @patch("toolbox_adk.toolset.ToolboxClient") + @pytest.mark.asyncio + @pytest.mark.parametrize( + "authn,authz,should_raise", + [ + ({}, [], True), # No requirements, token is completely unused + ({"param1": ["service"]}, [], False), # authn natively consumes it + ({}, ["service"], False), # authz natively consumes it + ( + {"param1": ["other"]}, + ["service"], + False, + ), # unused by authn, but authz consumes it + ( + {"param1": ["service"]}, + ["other"], + False, + ), # authn consumes it, authz doesn't + ( + {"param1": ["other"]}, + ["other"], + True, + ), # Requirements exist, but token is unused by both + ], + ) + async def test_get_tools_auth_validation( + self, mock_client_cls, authn, authz, should_raise + ): + mock_client = mock_client_cls.return_value + + t1 = MagicMock() + t1.__name__ = "tool1" + t1._required_authn_params = authn + t1._required_authz_tokens = authz + mock_client.load_tool = AsyncMock(return_value=t1) + + auth_getters = {"service": lambda: "token"} + toolset = ToolboxToolset( + "url", tool_names=["toolA"], auth_token_getters=auth_getters ) + if should_raise: + with pytest.raises( + ValueError, + match="unused auth tokens could not be applied to any tool: service", + ): + await toolset.get_tools() + else: + tools = await toolset.get_tools() + assert len(tools) == 1 + @patch("toolbox_adk.toolset.ToolboxClient") @pytest.mark.asyncio async def test_close(self, mock_client_cls): diff --git a/packages/toolbox-core/integration.cloudbuild.yaml b/packages/toolbox-core/integration.cloudbuild.yaml index cc28fc9f..e6dc930d 100644 --- a/packages/toolbox-core/integration.cloudbuild.yaml +++ b/packages/toolbox-core/integration.cloudbuild.yaml @@ -24,7 +24,8 @@ steps: uv venv /workspace/venv source /workspace/venv/bin/activate uv pip install uv - uv pip install -r requirements.txt -e '.[test]' + uv pip install -e '.[test]' + uv pip install -r requirements.txt entrypoint: /bin/bash - id: Run integration tests name: 'python:${_VERSION}' diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 79e598e6..d3a0009d 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -30,7 +30,12 @@ ) from .protocol import Protocol, ToolSchema from .tool import ToolboxTool -from .utils import identify_auth_requirements, resolve_value, warn_if_http_and_headers +from .utils import ( + identify_auth_requirements, + resolve_value, + validate_unused_requirements, + warn_if_http_and_headers, +) class ToolboxClient: @@ -238,20 +243,14 @@ async def load_tool( provided_auth_keys = set(auth_token_getters.keys()) provided_bound_keys = set(bound_params.keys()) - unused_auth = provided_auth_keys - used_auth_keys - unused_bound = provided_bound_keys - used_bound_keys - - if unused_auth or unused_bound: - error_messages = [] - if unused_auth: - error_messages.append(f"unused auth tokens: {', '.join(unused_auth)}") - if unused_bound: - error_messages.append( - f"unused bound parameters: {', '.join(unused_bound)}" - ) - raise ValueError( - f"Validation failed for tool '{name}': { '; '.join(error_messages) }." - ) + validate_unused_requirements( + provided_auth_keys, + provided_bound_keys, + used_auth_keys, + used_bound_keys, + name, + is_toolset=False, + ) return tool @@ -318,41 +317,26 @@ async def load_toolset( tools.append(tool) if strict: - unused_auth = provided_auth_keys - used_auth_keys - unused_bound = provided_bound_keys - used_bound_keys - if unused_auth or unused_bound: - error_messages = [] - if unused_auth: - error_messages.append( - f"unused auth tokens: {', '.join(unused_auth)}" - ) - if unused_bound: - error_messages.append( - f"unused bound parameters: {', '.join(unused_bound)}" - ) - raise ValueError( - f"Validation failed for tool '{tool_name}': { '; '.join(error_messages) }." - ) + validate_unused_requirements( + provided_auth_keys, + provided_bound_keys, + used_auth_keys, + used_bound_keys, + tool_name, + is_toolset=False, + ) else: overall_used_auth_keys.update(used_auth_keys) overall_used_bound_params.update(used_bound_keys) - unused_auth = provided_auth_keys - overall_used_auth_keys - unused_bound = provided_bound_keys - overall_used_bound_params - - if unused_auth or unused_bound: - error_messages = [] - if unused_auth: - error_messages.append( - f"unused auth tokens could not be applied to any tool: {', '.join(unused_auth)}" - ) - if unused_bound: - error_messages.append( - f"unused bound parameters could not be applied to any tool: {', '.join(unused_bound)}" - ) - raise ValueError( - f"Validation failed for toolset '{name or 'default'}': { '; '.join(error_messages) }." - ) + validate_unused_requirements( + provided_auth_keys, + provided_bound_keys, + overall_used_auth_keys, + overall_used_bound_params, + name or "default", + is_toolset=True, + ) return tools diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index 00c00157..fabde6c1 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -163,3 +163,47 @@ async def resolve_value( elif callable(source): return source() return source + + +def validate_unused_requirements( + provided_auth_keys: set[str], + provided_bound_keys: set[str], + used_auth_keys: set[str], + used_bound_keys: set[str], + name: str, + is_toolset: bool = False, + target_type: str | None = None, +) -> None: + """ + Validates that no provided authentication tokens or bound parameters went unused. + Raises a ValueError if any unused requirements are found, formatted appropriately + for either a single tool or a full toolset. + """ + unused_auth = provided_auth_keys - used_auth_keys + unused_bound = provided_bound_keys - used_bound_keys + + if unused_auth or unused_bound: + error_messages = [] + if unused_auth: + if is_toolset: + error_messages.append( + f"unused auth tokens could not be applied to any tool: {', '.join(unused_auth)}" + ) + else: + error_messages.append(f"unused auth tokens: {', '.join(unused_auth)}") + if unused_bound: + if is_toolset: + error_messages.append( + f"unused bound parameters could not be applied to any tool: {', '.join(unused_bound)}" + ) + else: + error_messages.append( + f"unused bound parameters: {', '.join(unused_bound)}" + ) + + final_target_type = ( + target_type if target_type else ("toolset" if is_toolset else "tool") + ) + raise ValueError( + f"Validation failed for {final_target_type} '{name}': {'; '.join(error_messages)}." + ) diff --git a/packages/toolbox-langchain/integration.cloudbuild.yaml b/packages/toolbox-langchain/integration.cloudbuild.yaml index 32e87863..ac26005d 100644 --- a/packages/toolbox-langchain/integration.cloudbuild.yaml +++ b/packages/toolbox-langchain/integration.cloudbuild.yaml @@ -28,7 +28,8 @@ steps: # Use $$ to escape shell variable for Cloud Build CORE_VERSION=$$(python -c "v={}; exec(open('../toolbox-core/src/toolbox_core/version.py').read(), v); print(v['__version__'])") sed -i "s/toolbox-core==[0-9.]*/toolbox-core==$$CORE_VERSION/g" pyproject.toml - uv pip install -r requirements.txt -e '.[test]' + uv pip install -r requirements.txt + uv pip install -e '.[test]' entrypoint: /bin/bash - id: Run integration tests name: 'python:${_VERSION}' diff --git a/packages/toolbox-llamaindex/integration.cloudbuild.yaml b/packages/toolbox-llamaindex/integration.cloudbuild.yaml index 5d7b5dc8..d3586bf2 100644 --- a/packages/toolbox-llamaindex/integration.cloudbuild.yaml +++ b/packages/toolbox-llamaindex/integration.cloudbuild.yaml @@ -28,7 +28,8 @@ steps: # Use $$ to escape shell variable for Cloud Build CORE_VERSION=$$(python -c "v={}; exec(open('../toolbox-core/src/toolbox_core/version.py').read(), v); print(v['__version__'])") sed -i "s/toolbox-core==[0-9.]*/toolbox-core==$$CORE_VERSION/g" pyproject.toml - uv pip install -r requirements.txt -e '.[test]' + uv pip install -r requirements.txt + uv pip install -e '.[test]' entrypoint: /bin/bash - id: Run integration tests name: 'python:${_VERSION}'