Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/lint-toolbox-adk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ jobs:
with:
python-version: "3.13"

- name: Install test requirements
run: pip install -e .[test]

- name: Install library requirements
run: pip install -r requirements.txt

- name: Install test requirements
run: pip install .[test]

- name: Run linters
run: |
black --check .
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/lint-toolbox-core.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ jobs:
with:
python-version: "3.13"

- name: Install test requirements
run: pip install -e .[test]

- name: Install library requirements
run: pip install -r requirements.txt

- name: Install test requirements
run: pip install .[test]

- name: Run linters
run: |
black --check .
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/lint-toolbox-langchain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ jobs:
with:
python-version: "3.13"

- name: Install test requirements
run: pip install -e .[test]

- name: Install library requirements
run: pip install -r requirements.txt

- name: Install test requirements
run: pip install .[test]

- name: Run linters
run: |
black --check .
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/lint-toolbox-llamaindex.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ jobs:
with:
python-version: "3.13"

- name: Install test requirements
run: pip install -e .[test]

- name: Install library requirements
run: pip install -r requirements.txt

- name: Install test requirements
run: pip install .[test]

- name: Run linters
run: |
black --check .
Expand Down
8 changes: 6 additions & 2 deletions packages/toolbox-adk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ creds = CredentialStrategy.from_adk_credentials(auth_credential, scheme)

Some tools may define their own authentication requirements (e.g., Salesforce OAuth, GitHub PAT) via `authSources` in their schema. You can provide a mapping of getters to resolve these tokens at runtime.

> [!TIP]
> Getters can optionally accept the ADK `ToolContext` as a single argument. This enables seamless integration of dynamic, end-user tokens that are tied to the current agent execution state.

```python
async def get_salesforce_token():
# Fetch token from secret manager or reliable source
Expand All @@ -218,8 +221,9 @@ async def get_salesforce_token():
toolset = ToolboxToolset(
server_url="...",
auth_token_getters={
"salesforce-auth": get_salesforce_token, # Async callable
"github-pat": lambda: "my-pat-token" # Sync callable or static lambda
"salesforce-auth": get_salesforce_token, # Async callable
"github-pat": lambda: "my-pat-token", # Sync callable or static lambda
"oauth-user": lambda ctx: ctx.state.get("auth_token") # Dynamic context-aware callable
}
)
```
Expand Down
3 changes: 2 additions & 1 deletion packages/toolbox-adk/integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ steps:
# Use $$ to escape shell variable for Cloud Build
CORE_VERSION=$$(python -c "v={}; exec(open('../toolbox-core/src/toolbox_core/version.py').read(), v); print(v['__version__'])")
sed -i "s/toolbox-core==[0-9.]*/toolbox-core==$$CORE_VERSION/g" pyproject.toml
uv pip install -r requirements.txt -e '.[test]'
uv pip install -r requirements.txt
uv pip install -e '.[test]'
entrypoint: /bin/bash
- id: Run integration tests
name: 'python:${_VERSION}'
Expand Down
29 changes: 27 additions & 2 deletions packages/toolbox-adk/src/toolbox_adk/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
from typing import Any, Awaitable, Callable, Dict, Optional
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional

import google.adk.auth.exchanger.oauth2_credential_exchanger as oauth2_credential_exchanger
import google.adk.auth.oauth2_credential_util as oauth2_credential_util
Expand Down Expand Up @@ -68,11 +69,13 @@ def __init__(
self,
core_tool: CoreToolboxTool,
auth_config: Optional[CredentialConfig] = None,
adk_token_getters: Optional[Mapping[str, Any]] = None,
):
"""
Args:
core_tool: The underlying toolbox_core.py tool instance.
auth_config: Credential configuration to handle interactive flows.
adk_token_getters: Tool-specific auth token getters.
"""
# We act as a proxy.
# We need to extract metadata from the core tool to satisfy BaseTool's contract.
Expand All @@ -95,6 +98,7 @@ def __init__(
)
self._core_tool = core_tool
self._auth_config = auth_config
self._adk_token_getters = adk_token_getters or {}

def _param_type_to_schema_type(self, param_type: str) -> Type:
type_map = {
Expand Down Expand Up @@ -260,12 +264,32 @@ async def run_async(
"Falling back to request_credential.",
exc_info=True,
)
# Fallback to request logic
tool_context.request_credential(auth_config_adk)
return {
"error": f"OAuth2 Credentials required for {self.name}. A consent link has been generated for the user. Do NOT attempt to run this tool again until the user confirms they have logged in."
}

if self._adk_token_getters:
# Pre-filter toolset getters to avoid unused-token errors from the core tool.
# This deferred loop also enables dynamic 1-arity `tool_context` injection.
needed_services = set()
for reqs in self._core_tool._required_authn_params.values():
needed_services.update(reqs)
needed_services.update(self._core_tool._required_authz_tokens)

for service, getter in self._adk_token_getters.items():
if service in needed_services:
sig = inspect.signature(getter)

if len(sig.parameters) == 1:
bound_getter = lambda t=getter, ctx=tool_context: t(ctx)
else:
bound_getter = getter

self._core_tool = self._core_tool.add_auth_token_getter(
service, bound_getter
)

result: Optional[Any] = None
error: Optional[Exception] = None

Expand All @@ -288,4 +312,5 @@ def bind_params(self, bounded_params: Dict[str, Any]) -> "ToolboxTool":
return ToolboxTool(
core_tool=new_core_tool,
auth_config=self._auth_config,
adk_token_getters=self._adk_token_getters,
)
40 changes: 36 additions & 4 deletions packages/toolbox-adk/src/toolbox_adk/toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.base_toolset import BaseToolset
from google.adk.tools.tool_context import ToolContext
from toolbox_core.utils import validate_unused_requirements
from typing_extensions import override

from .client import ToolboxClient
Expand All @@ -41,7 +42,15 @@ def __init__(
] = None,
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
auth_token_getters: Optional[
Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]]]]
Mapping[
str,
Union[
Callable[[], str],
Callable[[], Awaitable[str]],
Callable[[ToolContext], str],
Callable[[ToolContext], Awaitable[str]],
],
]
] = None,
**kwargs: Any,
):
Expand Down Expand Up @@ -91,7 +100,6 @@ async def get_tools(
core_tools = await self.client.load_toolset(
self.__toolset_name,
bound_params=self.__bound_params or {},
auth_token_getters=self.__auth_token_getters or {},
)
tools.extend(core_tools)

Expand All @@ -101,7 +109,6 @@ async def get_tools(
core_tool = await self.client.load_tool(
name,
bound_params=self.__bound_params or {},
auth_token_getters=self.__auth_token_getters or {},
)
tools.append(core_tool)

Expand All @@ -110,15 +117,40 @@ async def get_tools(
core_tools = await self.client.load_toolset(
None,
bound_params=self.__bound_params or {},
auth_token_getters=self.__auth_token_getters or {},
)
tools.extend(core_tools)

# 4. Strictly validate unused toolset auth token getters using core logic
if self.__auth_token_getters:
overall_used_auth_keys = set()
for t in tools:
for reqs in t._required_authn_params.values():
overall_used_auth_keys.update(reqs)
overall_used_auth_keys.update(t._required_authz_tokens)

# Generate intuitive name for the error string if a specific toolset wasn't used
validation_name = self.__toolset_name
if not validation_name:
validation_name = (
", ".join(self.__tool_names) if self.__tool_names else "default"
)

validate_unused_requirements(
provided_auth_keys=set(self.__auth_token_getters.keys()),
provided_bound_keys=set(),
used_auth_keys=overall_used_auth_keys,
used_bound_keys=set(),
name=validation_name,
is_toolset=True,
target_type="list of tools" if not self.__toolset_name else None,
)

# Wrap all core tools in ToolboxTool
return [
ToolboxTool(
core_tool=t,
auth_config=self.client.credential_config,
adk_token_getters=self.__auth_token_getters,
)
for t in tools
]
Expand Down
43 changes: 42 additions & 1 deletion packages/toolbox-adk/tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,48 @@ async def test_run_tool_unauth_with_auth(self, auth_token2: str):
try:
with pytest.raises(
ValueError,
match=rf"Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth",
match=rf"Validation failed for list of tools 'get-row-by-id': unused auth tokens could not be applied to any tool: my-test-auth",
):
await toolset.get_tools()
finally:
await toolset.close()

async def test_run_multiple_tools_unauth_with_auth(self, auth_token2: str):
"""Tests running multiple tools that don't require auth, verifying formatting of tool lists."""
toolset = ToolboxToolset(
server_url="http://localhost:5000",
tool_names=["get-row-by-id", "search-rows"],
auth_token_getters={"my-test-auth": lambda: auth_token2},
credentials=CredentialStrategy.toolbox_identity(),
)
try:
with pytest.raises(
ValueError,
match=rf"Validation failed for list of tools 'get-row-by-id, search-rows': unused auth tokens could not be applied to any tool: my-test-auth",
):
await toolset.get_tools()
finally:
await toolset.close()

async def test_run_multiple_tools_partial_auth_usage(self, auth_token2: str):
"""Tests that when some tokens are used and some aren't across diverse tools, only the truly unused tokens appear in the error."""
toolset = ToolboxToolset(
server_url="http://localhost:5000",
tool_names=[
"get-row-by-id-auth",
"search-rows",
], # first requires 'my-test-auth', second requires nothing
auth_token_getters={
"my-test-auth": lambda: auth_token2,
"extra-token": lambda: "fake",
},
credentials=CredentialStrategy.toolbox_identity(),
)
try:
with pytest.raises(
ValueError,
# 'my-test-auth' should be cleanly consumed and absent from the final error string.
match=r"Validation failed for list of tools 'get-row-by-id-auth, search-rows': unused auth tokens could not be applied to any tool: extra-token",
):
await toolset.get_tools()
finally:
Expand Down
42 changes: 42 additions & 0 deletions packages/toolbox-adk/tests/unit/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,48 @@ async def test_bind_params(self):
assert new_tool._core_tool == new_core_mock
mock_core.bind_params.assert_called_with({"a": 1})

@pytest.mark.asyncio
async def test_dynamic_adk_token_getters(self):
core_tool = AsyncMock()
core_tool.__name__ = "mock"
core_tool.__doc__ = "mock doc"
core_tool._required_authn_params = {"param1": ["service1"]}
core_tool._required_authz_tokens = ["service2"]
core_tool.add_auth_token_getter = MagicMock(return_value=core_tool)

def getter1():
return "token1"

def getter2(ctx):
return ctx.state.get("token2")

adk_getters = {
"service1": getter1,
"service2": getter2,
}

tool = ToolboxTool(core_tool, adk_token_getters=adk_getters)

ctx = MagicMock()
ctx.state = {"token2": "dynamic_token2"}

await tool.run_async({}, ctx)

assert core_tool.add_auth_token_getter.call_count == 2

args1 = core_tool.add_auth_token_getter.call_args_list[0][0]
args2 = core_tool.add_auth_token_getter.call_args_list[1][0]

# Because we iterate over items(), order might be dependent.
# Check that both services were processed and bound correctly
bound_getters = {args1[0]: args1[1], args2[0]: args2[1]}

assert "service1" in bound_getters
assert bound_getters["service1"]() == "token1"

assert "service2" in bound_getters
assert bound_getters["service2"]() == "dynamic_token2"

@pytest.mark.asyncio
async def test_3lo_missing_client_secret(self):
# Test ValueError when client_id/secret missing
Expand Down
Loading
Loading