Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/strands/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import logging
import mimetypes
import secrets
from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast

import pydantic
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
logger.debug("config=<%s> | initializing", self.config)

self.client_args = client_args or {}
self._tool_use_id_to_name: dict[str, str] = {}

@override
def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: ignore[override]
Expand Down Expand Up @@ -123,10 +125,13 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par
return genai.types.Part(text=content["text"])

if "toolResult" in content:
tool_use_id = content["toolResult"]["toolUseId"]
function_name = self._tool_use_id_to_name.get(tool_use_id, tool_use_id)

return genai.types.Part(
function_response=genai.types.FunctionResponse(
id=content["toolResult"]["toolUseId"],
name=content["toolResult"]["toolUseId"],
id=tool_use_id,
name=function_name,
response={
"output": [
tool_result_content
Expand All @@ -141,6 +146,12 @@ def _format_request_content_part(self, content: ContentBlock) -> genai.types.Par
)

if "toolUse" in content:
# Store the mapping from toolUseId to name for later use in toolResult formatting.
# This mapping is built as we format the request, ensuring that when we encounter
# toolResult blocks (which come after toolUse blocks in the message history),
# we can look up the function name.
self._tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"]

return genai.types.Part(
function_call=genai.types.FunctionCall(
args=content["toolUse"]["input"],
Expand Down Expand Up @@ -264,16 +275,16 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent:
case "content_start":
match event["data_type"]:
case "tool":
# Note: toolUseId is the only identifier available in a tool result. However, Gemini requires
# that name be set in the equivalent FunctionResponse type. Consequently, we assign
# function name to toolUseId in our tool use block. And another reason, function_call is
# not guaranteed to have id populated.
function_call = event["data"].function_call
# Use Gemini's provided ID or generate one if missing
tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}"

return {
"contentBlockStart": {
"start": {
"toolUse": {
"name": event["data"].function_call.name,
"toolUseId": event["data"].function_call.name,
"name": function_call.name,
"toolUseId": tool_use_id,
},
},
},
Expand Down Expand Up @@ -364,6 +375,7 @@ async def stream(
ModelThrottledException: If the request is throttled by Gemini.
"""
request = self._format_request(messages, tool_specs, system_prompt, self.config.get("params"))
self._tool_use_id_to_name.clear()

client = genai.Client(**self.client_args).aio
try:
Expand Down
2 changes: 1 addition & 1 deletion tests/strands/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ async def test_stream_response_tool_use(gemini_client, model, messages, agenerat
exp_chunks = [
{"messageStart": {"role": "assistant"}},
{"contentBlockStart": {"start": {}}},
{"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "calculator"}}}},
{"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}},
{"contentBlockStop": {}},
{"contentBlockStop": {}},
Expand Down