From 99305d793de04b9ab71592573904c1190b77622a Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 26 Mar 2026 12:14:01 -0500 Subject: [PATCH 01/15] feat(py): respond/restart improvements --- py/README.md | 2 +- py/packages/genkit/src/genkit/__init__.py | 16 +- py/packages/genkit/src/genkit/_ai/_aio.py | 104 ++++-- .../genkit/src/genkit/_ai/_generate.py | 102 +++++- py/packages/genkit/src/genkit/_ai/_prompt.py | 262 ++++++++------ py/packages/genkit/src/genkit/_ai/_tools.py | 333 ++++++++++++++++-- .../genkit/src/genkit/_core/_action.py | 28 ++ .../genkit/src/genkit/_core/_registry.py | 8 + .../genkit/tests/genkit/ai/prompt_test.py | 79 ++++- .../genkit/tests/genkit/veneer/veneer_test.py | 50 ++- .../plugins/google_genai/models/gemini.py | 49 ++- .../test/models/googlegenai_gemini_test.py | 33 ++ py/pyproject.toml | 2 +- py/samples/README.md | 2 +- py/samples/tool-interrupts/README.md | 30 +- .../prompts/bank_transfer_host_cli.prompt | 14 + .../prompts/trivia_host_cli.prompt | 9 + .../tool-interrupts/src/approval_example.py | 194 ++++++++++ py/samples/tool-interrupts/src/main.py | 118 ------- .../tool-interrupts/src/respond_example.py | 163 +++++++++ py/uv.lock | 2 +- 21 files changed, 1264 insertions(+), 336 deletions(-) create mode 100644 py/samples/tool-interrupts/prompts/bank_transfer_host_cli.prompt create mode 100644 py/samples/tool-interrupts/prompts/trivia_host_cli.prompt create mode 100644 py/samples/tool-interrupts/src/approval_example.py delete mode 100755 py/samples/tool-interrupts/src/main.py create mode 100755 py/samples/tool-interrupts/src/respond_example.py diff --git a/py/README.md b/py/README.md index a54a194bb0..0ddb09f373 100644 --- a/py/README.md +++ b/py/README.md @@ -424,7 +424,7 @@ The Python samples are intentionally short and beginner-friendly. 
Use this table | `middleware` | Middleware | Observe or modify model requests with `use=[...]` | | `output-formats` | Structured Output | Text, enum, JSON object, array, and JSONL outputs | | `prompts` | Prompt, Dotprompt | Use `.prompt` files, helpers, variants, and streaming | -| `tool-interrupts` | Tool | Pause tool execution for human approval | +| `tool-interrupts` | Tool | Trivia (`respond_example.py`) and bank approval (`approval_example.py`) for interrupt / resume | | `tracing` | Tracing | Watch spans appear in the Dev UI as they start | | `vertexai-imagen` | Vertex AI, Image Generation | Generate an image with Vertex AI Imagen | diff --git a/py/packages/genkit/src/genkit/__init__.py b/py/packages/genkit/src/genkit/__init__.py index 2af08aa319..a5185d3d56 100644 --- a/py/packages/genkit/src/genkit/__init__.py +++ b/py/packages/genkit/src/genkit/__init__.py @@ -17,13 +17,19 @@ """Genkit — Build AI-powered applications.""" from genkit._ai._aio import ActionKind, ActionRunContext, Genkit +from genkit._ai._generate import ToolReference from genkit._ai._prompt import ( ExecutablePrompt, ModelStreamResponse, PromptGenerateOptions, - ResumeOptions, ) -from genkit._ai._tools import ToolInterruptError, ToolRunContext, tool_response +from genkit._ai._tools import ( + Interrupt, + ToolInterruptError, + ToolRunContext, + respond_to_interrupt, + restart_interrupted_tool, +) from genkit._core._action import Action, StreamResponse from genkit._core._error import GenkitError, PublicError from genkit._core._model import Document @@ -94,6 +100,9 @@ 'GenkitError', 'PublicError', 'ToolInterruptError', + 'Interrupt', + 'respond_to_interrupt', + 'restart_interrupted_tool', # Content types 'Constrained', 'CustomPart', @@ -126,9 +135,8 @@ 'ActionRunContext', 'ExecutablePrompt', 'PromptGenerateOptions', - 'ResumeOptions', 'ToolRunContext', - 'tool_response', + 'ToolReference', 'ModelRequest', 'ModelResponse', 'ModelResponseChunk', diff --git 
a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index a1902bb150..b8fa6076d3 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -25,7 +25,7 @@ import socket import threading import uuid -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Awaitable, Callable, Coroutine, Sequence from pathlib import Path from typing import Any, ParamSpec, TypeVar, cast, overload @@ -45,7 +45,7 @@ ) from genkit._ai._formats import built_in_formats from genkit._ai._formats._types import FormatDef -from genkit._ai._generate import define_generate_action, generate_action +from genkit._ai._generate import ToolReference, define_generate_action, generate_action from genkit._ai._model import ( Message, ModelConfig, @@ -71,7 +71,7 @@ ResourceOptions, define_resource, ) -from genkit._ai._tools import define_tool +from genkit._ai._tools import define_interrupt, define_tool from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._background import ( BackgroundAction, @@ -107,6 +107,8 @@ Part, SpanMetadata, ToolChoice, + ToolRequestPart, + ToolResponsePart, ) from ._decorators import _FlowDecorator, _FlowDecoratorWithChunk @@ -270,6 +272,42 @@ def wrapper(func: Callable[P, T]) -> Callable[P, T]: return wrapper + def define_interrupt( + self, + name: str, + *, + input_schema: type[BaseModel] | dict[str, object] | None = None, + output_schema: type[BaseModel] | dict[str, object] | None = None, + description: str | None = None, + ) -> Callable[P, T]: + """Register an interrupt tool that always pauses for user input. 
+ + Args: + name: Tool name + input_schema: Optional input schema (Pydantic model or JSON schema dict) + output_schema: Optional output schema (Pydantic model or JSON schema dict) + description: Tool description + + Returns: + The registered interrupt tool + + Example: + ask_user = ai.define_interrupt( + name='ask_user', + input_schema=Question, + output_schema=Answer, + description='Ask the user a question', + ) + """ + + return define_interrupt( + self.registry, + name, + description=description, + input_schema=input_schema, + output_schema=output_schema, + ) + def define_evaluator( self, *, @@ -393,7 +431,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -421,7 +459,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -449,7 +487,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -477,7 +515,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | 
None = None, @@ -503,7 +541,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -741,10 +779,12 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, - tool_responses: list[Part] | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -766,10 +806,12 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, - tool_responses: list[Part] | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -789,10 +831,12 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, 
return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, - tool_responses: list[Part] | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -817,7 +861,9 @@ async def generate( tools=tools, return_tool_requests=return_tool_requests, tool_choice=tool_choice, - tool_responses=tool_responses, + resume_respond=resume_respond, + resume_restart=resume_restart, + resume_metadata=resume_metadata, config=config, max_turns=max_turns, output_format=output_format, @@ -841,9 +887,12 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -866,9 +915,12 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, 
max_turns: int | None = None, context: dict[str, object] | None = None, @@ -889,9 +941,12 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -904,7 +959,13 @@ def generate_stream( docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[Any]: - """Stream generated text, returning a ModelStreamResponse with .stream and .response.""" + """Stream generated text, returning a ModelStreamResponse with .stream and .response. + + Middleware (``use=``) uses the same signatures as :meth:`generate`: + - Simple: ``(req, ctx, next)`` — 3 params + - Streaming-aware: ``(req, ctx, on_chunk, next)`` — 4 params + The framework auto-detects by parameter count. Both work with generate_stream. 
+ """ channel: Channel[ModelResponseChunk, ModelResponse[Any]] = Channel(timeout=timeout) async def _run_generate() -> ModelResponse[Any]: @@ -920,6 +981,9 @@ async def _run_generate() -> ModelResponse[Any]: tools=tools, return_tool_requests=return_tool_requests, tool_choice=tool_choice, + resume_respond=resume_respond, + resume_restart=resume_restart, + resume_metadata=resume_metadata, config=config, max_turns=max_turns, output_format=output_format, @@ -1130,7 +1194,7 @@ async def generate_operation( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, config: dict[str, object] | ModelConfig | None = None, diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 8d24cc9ba0..a8839e0a9a 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -19,8 +19,8 @@ import copy import inspect import re -from collections.abc import Callable -from typing import Any, cast +from collections.abc import Awaitable, Callable, Sequence +from typing import Any, TypeAlias, cast from pydantic import BaseModel @@ -35,7 +35,7 @@ ModelResponseChunk, ) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources -from genkit._ai._tools import ToolInterruptError +from genkit._ai._tools import ToolInterruptError, run_tool_after_restart from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import GenkitError from genkit._core._logger import get_logger @@ -56,6 +56,34 @@ logger = get_logger(__name__) +# Allowed ``tools=`` values for :meth:`~genkit.Genkit.generate`: +# - tool name (``str``) +# - :class:`~genkit.Action` (e.g. 
from prompt ``as_tool()``) +# - a function decorated with ``@ai.tool()`` (the same object you registered as a tool) +# Use ``Sequence`` instead of ``list``: type checkers treat ``list`` as strict about the +# exact item type, so ``[my_tool_fn]`` can fail to match ``list[str | Action | ...]`` even +# though it works at runtime. ``Sequence`` does not have that problem. +ToolReference: TypeAlias = str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]] + + +def tools_to_action_refs(tools: Sequence[ToolReference] | None) -> list[str] | None: + """Normalize tool arguments to registry names for :class:`GenerateActionOptions`.""" + if tools is None: + return None + refs: list[str] = [] + for t in tools: + if isinstance(t, str): + refs.append(t) + elif isinstance(t, Action): + refs.append(t.name) + else: + name = getattr(t, '__name__', None) + if not isinstance(name, str) or not name: + msg = f'Cannot resolve tool name from callable: {t!r}' + raise TypeError(msg) + refs.append(name) + return refs + # Matches data URIs: everything up to the first comma is the media-type + # parameters (e.g. "data:audio/L16;codec=pcm;rate=24000;base64,"). 
@@ -542,10 +570,7 @@ async def resolve_parameters( tools: list[Action[Any, Any, Any]] = [] if request.tools: for tool_name in request.tools: - tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) - if tool_action is None: - raise Exception(f'Unable to resolve tool {tool_name}') - tools.append(tool_action) + tools.append(await resolve_tool(registry, tool_name)) format_def: FormatDef | None = None if request.output and request.output.format: @@ -695,11 +720,18 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart async def resolve_tool(registry: Registry, tool_name: str) -> Action: - """Resolve a tool by name from the registry.""" - tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name) - if tool is None: - raise ValueError(f'Unable to resolve tool {tool_name}') - return tool + """Resolve a tool by name from the registry. + + Tries :class:`~genkit.ActionKind.TOOL` first, then prompt actions registered as + :class:`~genkit.ActionKind.PROMPT` / :class:`~genkit.ActionKind.EXECUTABLE_PROMPT` + (e.g. :meth:`~genkit.ExecutablePrompt.as_tool`). 
+ """ + for kind in (ActionKind.TOOL, ActionKind.PROMPT, ActionKind.EXECUTABLE_PROMPT): + tool = await registry.resolve_action(kind=kind, name=tool_name) + if tool is not None: + return tool + msg = f'Unable to resolve tool {tool_name}' + raise ValueError(msg) async def _resolve_resume_options( @@ -730,7 +762,7 @@ async def _resolve_resume_options( i += 1 continue - resumed_request, resumed_response = _resolve_resumed_tool_request(raw_request, part) + resumed_request, resumed_response = await _resolve_resumed_tool_request(_registry, raw_request, part) tool_responses.append(resumed_response) updated_content[i] = resumed_request i += 1 @@ -755,8 +787,10 @@ async def _resolve_resume_options( return (revised_request, None, tool_message) -def _resolve_resumed_tool_request(raw_request: GenerateActionOptions, tool_request_part: Part) -> tuple[Part, Part]: - """Resolve a single tool request from pending output or resume.respond list.""" +async def _resolve_resumed_tool_request( + registry: Registry, raw_request: GenerateActionOptions, tool_request_part: Part +) -> tuple[Part, Part]: + """Resolve a single tool request from pending output, resume.respond, or resume.restart.""" # Type narrowing: ensure we're working with a ToolRequestPart if not isinstance(tool_request_part.root, ToolRequestPart): raise GenkitError( @@ -812,6 +846,31 @@ def _resolve_resumed_tool_request(raw_request: GenerateActionOptions, tool_reque provided_response, ) + restart_trp = _find_corresponding_restart( + raw_request.resume.restart if raw_request.resume else None, + tool_req_root, + ) + if restart_trp: + tool = await resolve_tool(registry, tool_req_root.tool_request.name) + executed = await run_tool_after_restart(tool, restart_trp) + metadata = dict(tool_req_root.metadata) if tool_req_root.metadata else {} + interrupt = metadata.get('interrupt') + if interrupt: + del metadata['interrupt'] + return ( + Part( + root=ToolRequestPart( + tool_request=ToolRequest( + name=tool_req_root.tool_request.name, 
+ ref=tool_req_root.tool_request.ref, + input=tool_req_root.tool_request.input, + ), + metadata={**metadata, 'resolvedInterrupt': interrupt}, + ) + ), + executed, + ) + raise GenkitError( status='INVALID_ARGUMENT', message=f"Unresolved tool request '{tool_req_root.tool_request.name}' " @@ -820,6 +879,19 @@ def _resolve_resumed_tool_request(raw_request: GenerateActionOptions, tool_reque ) +def _find_corresponding_restart( + restarts: list[ToolRequestPart] | None, + request: ToolRequestPart, +) -> ToolRequestPart | None: + """Find a restart part matching the pending request by name and ref.""" + if not restarts: + return None + for trp in restarts: + if trp.tool_request.name == request.tool_request.name and trp.tool_request.ref == request.tool_request.ref: + return trp + return None + + def _find_corresponding_tool_response(responses: list[ToolResponsePart], request: ToolRequestPart) -> Part | None: """Find a response matching the request by name and ref.""" for p in responses: diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 82e0985ab3..664fab6f3f 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -20,11 +20,13 @@ import asyncio import os import weakref -from collections.abc import AsyncIterable, Awaitable, Callable +from collections.abc import AsyncIterable, Awaitable, Callable, Sequence from dataclasses import dataclass from pathlib import Path from typing import Any, ClassVar, Generic, TypedDict, TypeVar, cast +from typing_extensions import Unpack + from dotpromptz.typing import ( DataArgument, PromptFunction, @@ -32,10 +34,12 @@ PromptMetadata, ) from pydantic import BaseModel, ConfigDict - from genkit._ai._generate import ( + ToolReference, generate_action, + resolve_tool, to_tool_definition, + tools_to_action_refs, ) from genkit._ai._model import ( Message, @@ -83,12 +87,34 @@ class OutputOptions(TypedDict, total=False): constrained: bool 
| None -class ResumeOptions(TypedDict, total=False): - """Options for resuming generation after a tool interrupt.""" +def _normalize_resume_respond_parts( + value: ToolResponsePart | list[ToolResponsePart] | None, +) -> list[ToolResponsePart] | None: + if value is None: + return None + return list(value) if isinstance(value, list) else [value] + + +def _normalize_resume_restart_parts( + value: ToolRequestPart | list[ToolRequestPart] | None, +) -> list[ToolRequestPart] | None: + if value is None: + return None + return list(value) if isinstance(value, list) else [value] - respond: ToolResponsePart | list[ToolResponsePart] | None - restart: ToolRequestPart | list[ToolRequestPart] | None - metadata: dict[str, Any] | None + +def resume_options_to_resume( + *, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, +) -> Resume | None: + """Build wire :class:`Resume` from flat keyword options (``generate`` / prompts).""" + respond = _normalize_resume_respond_parts(resume_respond) + restart = _normalize_resume_restart_parts(resume_restart) + if respond is None and restart is None and resume_metadata is None: + return None + return Resume(respond=respond, restart=restart, metadata=resume_metadata) class PromptGenerateOptions(TypedDict, total=False): @@ -98,11 +124,13 @@ class PromptGenerateOptions(TypedDict, total=False): config: dict[str, Any] | ModelConfig | None messages: list[Message] | None docs: list[Document] | None - tools: list[str] | None + tools: Sequence[ToolReference] | None resources: list[str] | None tool_choice: ToolChoice | None output: OutputOptions | None - resume: ResumeOptions | None + resume_respond: ToolResponsePart | list[ToolResponsePart] | None + resume_restart: ToolRequestPart | list[ToolRequestPart] | None + resume_metadata: dict[str, Any] | None return_tool_requests: bool | None max_turns: int | None 
on_chunk: ModelStreamingCallback | None @@ -112,6 +140,27 @@ class PromptGenerateOptions(TypedDict, total=False): metadata: dict[str, Any] | None +_PROMPT_GENERATE_OPTION_KEYS: frozenset[str] = frozenset(PromptGenerateOptions.__annotations__) + + +def _coerce_prompt_opts(opts: PromptGenerateOptions) -> PromptGenerateOptions: + """Build effective opts from a kwargs mapping: drop keys whose value is None. + + Rejects unknown keys at runtime (Unpack only enforces this for type checkers). + """ + raw = cast(dict[str, Any], opts) + unknown = set(raw) - _PROMPT_GENERATE_OPTION_KEYS + if unknown: + if 'opts' in unknown: + raise TypeError( + "Passing a combined `opts` dict is not supported; use keyword arguments " + '(e.g. model=..., config=...) matching PromptGenerateOptions.' + ) + sorted_unknown = ', '.join(sorted(unknown)) + raise TypeError(f'Unexpected keyword arguments for prompt execution: {sorted_unknown}') + return cast(PromptGenerateOptions, {k: v for k, v in raw.items() if v is not None}) + + class ModelStreamResponse(Generic[OutputT]): """Response from streaming prompt execution with stream and response properties.""" @@ -183,11 +232,13 @@ class PromptConfig(BaseModel): max_turns: int | None = None return_tool_requests: bool | None = None metadata: dict[str, Any] | None = None - tools: list[str] | None = None + tools: Sequence[ToolReference] | None = None tool_choice: ToolChoice | None = None use: list[ModelMiddleware] | None = None docs: list[Document] | None = None - tool_responses: list[Part] | None = None + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None + resume_metadata: dict[str, Any] | None = None resources: list[str] | None = None @@ -213,7 +264,7 @@ def __init__( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, Any] | None = None, - tools: list[str] | None = None, + tools: Sequence[ToolReference] | None = None, 
tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -295,60 +346,42 @@ async def _ensure_resolved(self) -> None: async def __call__( self, - input: InputT | None = None, - opts: PromptGenerateOptions | None = None, + input: InputT | dict[str, Any] | None = None, + **opts: Unpack[PromptGenerateOptions], ) -> ModelResponse[OutputT]: - """Execute the prompt and return the response.""" - await self._ensure_resolved() - effective_opts: PromptGenerateOptions = opts if opts else {} + """Execute the prompt and return the response. - # Extract streaming callback and middleware from opts - on_chunk = effective_opts.get('on_chunk') - middleware = effective_opts.get('use') or self._use - context = effective_opts.get('context') + Args: + input: Template variables for rendering. + """ + effective_opts = _coerce_prompt_opts(cast(PromptGenerateOptions, opts)) + return await self._call_impl(input, effective_opts) + async def _call_impl( + self, + input: InputT | dict[str, Any] | None, + opts: PromptGenerateOptions, + ) -> ModelResponse[OutputT]: + """Execute the prompt with resolved opts. 
Used by __call__ and stream.""" + await self._ensure_resolved() + on_chunk = opts.get('on_chunk') + middleware = opts.get('use') or self._use + context = opts.get('context') result = await generate_action( self._registry, - await self.render(input=input, opts=effective_opts), + await self._render_impl(input, opts), on_chunk=on_chunk, middleware=middleware, context=context if context else ActionRunContext._current_context(), # pyright: ignore[reportPrivateUsage] ) - # Cast to preserve the generic type parameter return cast(ModelResponse[OutputT], result) - def stream( + async def _render_impl( self, - input: InputT | None = None, - opts: PromptGenerateOptions | None = None, - *, - timeout: float | None = None, - ) -> ModelStreamResponse[OutputT]: - """Stream the prompt execution, returning (stream, response_future).""" - effective_opts: PromptGenerateOptions = opts if opts else {} - channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout) - - # Create a copy of opts with the streaming callback - stream_opts: PromptGenerateOptions = { - **effective_opts, - 'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)), - } - - resp = self.__call__(input=input, opts=stream_opts) - response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp) - channel.set_close_future(response_future) - - return ModelStreamResponse[OutputT](channel=channel, response_future=response_future) - - async def render( - self, - input: InputT | dict[str, Any] | None = None, - opts: PromptGenerateOptions | None = None, + input: InputT | dict[str, Any] | None, + opts: PromptGenerateOptions, ) -> GenerateActionOptions: - """Render the prompt template without executing, returning GenerateActionOptions.""" - await self._ensure_resolved() - if opts is None: - opts = cast(PromptGenerateOptions, {}) + """Render the prompt with resolved opts. 
Used by render() and _call_impl.""" output_opts = opts.get('output') or {} context = opts.get('context') @@ -395,46 +428,37 @@ def _or(opt_val: Any, default: Any) -> Any: # noqa: ANN401 metadata=merged_metadata, docs=self._docs, resources=opts.get('resources') or self._resources, + resume_respond=opts.get('resume_respond'), + resume_restart=opts.get('resume_restart'), + resume_metadata=opts.get('resume_metadata'), ) - model = prompt_options.model or self._registry.default_model - if model is None: + model_name = prompt_options.model or self._registry.default_model + if model_name is None: raise GenkitError(status='INVALID_ARGUMENT', message='No model configured.') resolved_msgs: list[Message] = [] # Convert input to dict for render functions - # If input is a Pydantic model, convert to dict; otherwise use as-is render_input: dict[str, Any] if input is None: render_input = {} elif isinstance(input, dict): - # Type narrow: input is dict here, assign to dict[str, Any] typed variable render_input = {str(k): v for k, v in input.items()} elif isinstance(input, BaseModel): - # Pydantic v2 model render_input = input.model_dump() elif hasattr(input, 'dict'): - # Pydantic v1 model dict_func = getattr(input, 'dict', None) render_input = cast(Callable[[], dict[str, Any]], dict_func)() else: - # Fallback: cast to dict (should not happen with proper typing) render_input = cast(dict[str, Any], input) - # Get opts.messages for history (matching JS behavior) - opts_messages = opts.get('messages') - # Render system prompt + opts_messages = opts.get('messages') if prompt_options.system: result = await render_system_prompt( self._registry, render_input, prompt_options, self._cache_prompt, context ) resolved_msgs.append(result) - - # Handle messages (matching JS behavior): - # - If prompt has messages template: render it (opts.messages passed as history to resolvers) - # - If prompt has no messages: use opts.messages directly if prompt_options.messages: - # Prompt defines messages - 
render them (resolvers receive opts_messages as history) resolved_msgs.extend( await render_message_prompt( self._registry, @@ -446,59 +470,84 @@ def _or(opt_val: Any, default: Any) -> Any: # noqa: ANN401 ) ) elif opts_messages: - # Prompt has no messages template - use opts.messages directly resolved_msgs.extend(opts_messages) - - # Render user prompt if prompt_options.prompt: result = await render_user_prompt(self._registry, render_input, prompt_options, self._cache_prompt, context) resolved_msgs.append(result) - # If schema is set but format is not explicitly set, default to 'json' format if prompt_options.output_schema and not prompt_options.output_format: - output_format = 'json' + out_format = 'json' else: - output_format = prompt_options.output_format + out_format = prompt_options.output_format - # Build output config - output = GenerateActionOutputConfig() - if output_format: - output.format = output_format + output_config = GenerateActionOutputConfig() + if out_format: + output_config.format = out_format if prompt_options.output_content_type: - output.content_type = prompt_options.output_content_type + output_config.content_type = prompt_options.output_content_type if prompt_options.output_instructions is not None: - output.instructions = prompt_options.output_instructions - _resolve_output_schema(self._registry, prompt_options.output_schema, output) + output_config.instructions = prompt_options.output_instructions + _resolve_output_schema(self._registry, prompt_options.output_schema, output_config) if prompt_options.output_constrained is not None: - output.constrained = prompt_options.output_constrained + output_config.constrained = prompt_options.output_constrained - # Handle resume options - resume = None - resume_opts = opts.get('resume') - if resume_opts: - respond = resume_opts.get('respond') - if respond: - resume = Resume(respond=respond) if isinstance(respond, list) else Resume(respond=[respond]) + resume_result = resume_options_to_resume( + 
resume_respond=opts.get('resume_respond'), + resume_restart=opts.get('resume_restart'), + resume_metadata=opts.get('resume_metadata'), + ) - # Merge docs: opts.docs extends prompt docs merged_docs = await render_docs(render_input, prompt_options, context) opts_docs = opts.get('docs') if opts_docs: merged_docs = [*merged_docs, *opts_docs] if merged_docs else list(opts_docs) return GenerateActionOptions( - model=model, - messages=resolved_msgs, + model=model_name, + messages=resolved_msgs, # type: ignore[arg-type] config=prompt_options.config, - tools=prompt_options.tools, + tools=tools_to_action_refs(prompt_options.tools), return_tool_requests=prompt_options.return_tool_requests, tool_choice=prompt_options.tool_choice, - output=output, + output=output_config, max_turns=prompt_options.max_turns, docs=merged_docs, # type: ignore[arg-type] - resume=resume, + resume=resume_result, ) + def stream( + self, + input: InputT | dict[str, Any] | None = None, + *, + timeout: float | None = None, + **opts: Unpack[PromptGenerateOptions], + ) -> ModelStreamResponse[OutputT]: + """Stream the prompt execution, returning (stream, response_future).""" + effective_opts = _coerce_prompt_opts(cast(PromptGenerateOptions, opts)) + channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout) + stream_opts: PromptGenerateOptions = { + **effective_opts, + 'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)), + } + resp = self._call_impl(input, stream_opts) + response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp) + channel.set_close_future(response_future) + + return ModelStreamResponse[OutputT](channel=channel, response_future=response_future) + + async def render( + self, + input: InputT | dict[str, Any] | None = None, + **opts: Unpack[PromptGenerateOptions], + ) -> GenerateActionOptions: + """Render the prompt template without executing, returning GenerateActionOptions. 
+ + Same keyword options as :meth:`__call__` (see :class:`PromptGenerateOptions`). + """ + await self._ensure_resolved() + coerced = _coerce_prompt_opts(cast(PromptGenerateOptions, opts)) + return await self._render_impl(input, coerced) + async def as_tool(self) -> Action: """Expose this prompt as a tool. @@ -651,18 +700,20 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) if options.output_constrained is not None: output.constrained = options.output_constrained - resume = None - if options.tool_responses: - # Filter for only ToolResponsePart instances - tool_response_parts = [r.root for r in options.tool_responses if isinstance(r.root, ToolResponsePart)] - if tool_response_parts: - resume = Resume(respond=tool_response_parts) + resume = resume_options_to_resume( + resume_respond=options.resume_respond, + resume_restart=options.resume_restart, + resume_metadata=options.resume_metadata, + ) + + # Convert tool refs (name, Action, or @tool callable) to string names for GenerateActionOptions + tools_refs = tools_to_action_refs(options.tools) return GenerateActionOptions( model=model, messages=resolved_msgs, # type: ignore[arg-type] config=options.config, - tools=options.tools, + tools=tools_refs, return_tool_requests=options.return_tool_requests, tool_choice=options.tool_choice, output=output, @@ -676,11 +727,8 @@ async def to_generate_request(registry: Registry, options: GenerateActionOptions """Convert GenerateActionOptions to ModelRequest, resolving tool names.""" tools: list[Action] = [] if options.tools: - for tool_name in options.tools: - tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) - if tool_action is None: - raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_name}') - tools.append(tool_action) + for tool_ref in options.tools: + tools.append(await resolve_tool(registry, tool_ref)) tool_defs = [to_tool_definition(tool) for tool in tools] if tools else [] diff --git 
a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index f603f66f79..d7bf09e3de 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -18,16 +18,24 @@ import inspect from collections.abc import Callable +from contextvars import ContextVar from functools import wraps -from typing import Any, NoReturn, ParamSpec, TypeVar, cast +from typing import Any, ParamSpec, TypeVar, cast -from genkit._core._action import ActionKind, ActionRunContext +from pydantic import BaseModel + +from genkit._core._action import Action, ActionKind, ActionRunContext +from genkit._core._error import GenkitError from genkit._core._registry import Registry from genkit._core._typing import Part, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart P = ParamSpec('P') T = TypeVar('T') +# Context variables for propagating resumed metadata to tools +_tool_resumed_metadata: ContextVar[dict[str, Any] | None] = ContextVar('tool_resumed_metadata', default=None) +_tool_original_input: ContextVar[Any | None] = ContextVar('tool_original_input', default=None) + class ToolRunContext(ActionRunContext): """Tool execution context with interrupt support.""" @@ -35,13 +43,23 @@ class ToolRunContext(ActionRunContext): def __init__( self, ctx: ActionRunContext, + resumed_metadata: dict[str, Any] | None = None, + original_input: Any = None, ) -> None: - """Initialize from parent ActionRunContext.""" + """Initialize from parent ActionRunContext. 
+ + Args: + ctx: Parent action context + resumed_metadata: Metadata from previous interrupt (if resumed) + original_input: Original tool input before replacement (if resumed) + """ super().__init__(context=ctx.context) + self.resumed_metadata = resumed_metadata + self.original_input = original_input - def interrupt(self, metadata: dict[str, Any] | None = None) -> NoReturn: - """Raise ToolInterruptError to pause execution (e.g., for user input).""" - raise ToolInterruptError(metadata=metadata) + def is_resumed(self) -> bool: + """Return True if this execution is resuming after an interrupt.""" + return self.resumed_metadata is not None # TODO(#4346): make this extend GenkitError once it has INTERRUPTED status @@ -54,30 +72,143 @@ def __init__(self, metadata: dict[str, Any] | None = None) -> None: self.metadata: dict[str, Any] = metadata or {} -def tool_response( - interrupt: Part | ToolRequestPart, - response_data: object | None = None, +class Interrupt(Exception): + """Exception for interrupting tool execution with user-facing API. + + Use ``raise Interrupt(data)`` inside a tool to pause. Prefer this over raising + :class:`ToolInterruptError` directly. + + To resume, use :func:`respond_to_interrupt` or :func:`restart_interrupted_tool` (or + ``tool.restart`` on the registered tool callable). + """ + + def __init__(self, data: dict[str, Any] | None = None) -> None: + """Initialize an Interrupt exception. + + Args: + data: Interrupt metadata (attached to the tool request on the wire). Use a + plain dict; for a Pydantic model, pass ``m.model_dump(mode="json")``. 
+ """ + super().__init__() + self.data: dict[str, Any] = {} if data is None else data + + +def _tool_response_part( + interrupt: ToolRequestPart, + output: Any, + metadata: dict[str, Any] | None = None, +) -> ToolResponsePart: + """Build a ``ToolResponsePart`` for an interrupted tool request (interrupt reply channel).""" + interrupt_metadata = metadata if metadata is not None else True + tool_req = interrupt.tool_request + return ToolResponsePart( + tool_response=ToolResponse( + ref=tool_req.ref, + name=tool_req.name, + output=output, + ), + metadata={'interruptResponse': interrupt_metadata}, + ) + + +def respond_to_interrupt( + response: Any, + *, + interrupt: ToolRequestPart, metadata: dict[str, Any] | None = None, -) -> Part: - """Create a ToolResponse Part for an interrupted tool request.""" - # TODO(#4347): validate against tool schema - tool_request = interrupt.root.tool_request if isinstance(interrupt, Part) else interrupt.tool_request - - interrupt_metadata: dict[str, Any] | bool = True - if isinstance(metadata, dict): - interrupt_metadata = metadata - elif metadata: - interrupt_metadata = metadata - - tr = cast(ToolRequest, tool_request) +) -> ToolResponsePart: + """Build a ``ToolResponsePart`` for a pending tool interrupt. + + Pass the return value to ``generate(..., resume_respond=interrupt_response)``. + + Args: + response: Tool output / user reply for this interrupt. + interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). + metadata: Optional metadata for the interrupt response channel. + """ + return _tool_response_part(interrupt, response, metadata) + + +def restart_interrupted_tool( + replace_input: Any | None = None, # noqa: ANN401 + *, + interrupt: ToolRequestPart, + tool: Callable[..., Any], + resumed_metadata: dict[str, Any] | None = None, +) -> ToolRequestPart: + """Build a restart ``ToolRequestPart`` for an interrupted tool call. 
+ + Thin wrapper around ``tool.restart(...)`` on the registered tool callable (from + :func:`define_tool` or :meth:`Genkit.tool`). Use when you want a module-level helper + parallel to :func:`respond_to_interrupt`. + + Args: + replace_input: New ``tool_request.input`` for the redo, or ``None`` to keep the + interrupted request's input. + interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). + tool: The same tool function / wrapper that was registered (must expose ``.restart``). + resumed_metadata: Passed through to :attr:`ToolRunContext.resumed_metadata`. + + Returns: + A ``ToolRequestPart`` suitable for ``generate(..., resume_restart=[...])`` / message history. + """ + restart = getattr(tool, 'restart', None) + if restart is None: + raise TypeError( + f'{tool!r} has no restart method; pass a tool from define_tool or Genkit.tool' + ) + part = restart( + replace_input, + interrupt=interrupt, + resumed_metadata=resumed_metadata, + ) + root = part.root + if not isinstance(root, ToolRequestPart): + msg = f'Expected tool.restart() to return Part(root=ToolRequestPart), got {type(root)!r}' + raise TypeError(msg) + return root + + +async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolRequestPart) -> Part: + """Run a tool for ``resume_restart``: applies ``resumed`` / ``replacedInput`` from metadata. + + Sets the same context variables as the tool wrapper so :class:`ToolRunContext` reflects + a resumed run. Nested interrupts during restart are not supported and raise :class:`GenkitError`. 
+ """ + meta = restart_trp.metadata or {} + raw_resumed = meta.get('resumed') + if raw_resumed is True: + resumed_meta: dict[str, Any] | None = {} + elif isinstance(raw_resumed, dict): + resumed_meta = raw_resumed + else: + resumed_meta = None + original_input = meta.get('replacedInput') + + token_meta = _tool_resumed_metadata.set(resumed_meta) + token_input = _tool_original_input.set(original_input) + try: + try: + tool_response = (await tool.run(restart_trp.tool_request.input)).response + except GenkitError as e: + if e.cause and isinstance(e.cause, ToolInterruptError): + raise GenkitError( + status='FAILED_PRECONDITION', + message='Tool interrupted again during a restart execution; not supported yet.', + cause=e.cause, + ) from e + raise + finally: + _tool_resumed_metadata.reset(token_meta) + _tool_original_input.reset(token_input) + return Part( root=ToolResponsePart( tool_response=ToolResponse( - name=tr.name, - ref=tr.ref, - output=response_data, - ), - metadata={'interruptResponse': interrupt_metadata}, + name=restart_trp.tool_request.name, + ref=restart_trp.tool_request.ref, + output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, + ) ) ) @@ -96,6 +227,9 @@ def define_tool( func: Callable[P, T], name: str | None = None, description: str | None = None, + *, + input_schema: type[BaseModel] | dict[str, object] | None = None, + output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> Callable[P, T]: """Register a function as a tool. @@ -104,6 +238,8 @@ def define_tool( func: The async function to register as a tool. Must be a coroutine function. name: Optional name for the tool. Defaults to the function name. description: Optional description. Defaults to the function's docstring. + input_schema: Optional override for tool input JSON schema / validation (Pydantic model or dict). + output_schema: Optional override for tool output JSON schema (Pydantic model or dict). 
Raises: TypeError: If func is not an async function. @@ -121,15 +257,29 @@ def define_tool( async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 # Dynamic dispatch based on function signature - pyright can't verify ParamSpec here - match len(input_spec.args): - case 0: - return await func_any() - case 1: - return await func_any(args[0]) - case 2: - return await func_any(args[0], ToolRunContext(cast(ActionRunContext, args[1]))) - case _: - raise ValueError('tool must have 0-2 args...') + try: + match len(input_spec.args): + case 0: + return await func_any() + case 1: + return await func_any(args[0]) + case 2: + # Read from context variables for resumed metadata + resumed_meta = _tool_resumed_metadata.get() + original_input = _tool_original_input.get() + return await func_any( + args[0], + ToolRunContext( + cast(ActionRunContext, args[1]), + resumed_metadata=resumed_meta, + original_input=original_input, + ), + ) + case _: + raise ValueError('tool must have 0-2 args...') + except Interrupt as e: + # Convert Interrupt to ToolInterruptError for compatibility with existing flow + raise ToolInterruptError(metadata=e.data) from e action = registry.register_action( name=tool_name, @@ -137,6 +287,8 @@ async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 description=tool_description, fn=tool_fn_wrapper, metadata_fn=func, + input_schema=input_schema, + output_schema=output_schema, ) @wraps(func) @@ -144,4 +296,115 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: # noqa: ANN401 action_any = cast(Any, action) return (await action_any.run(*args, **kwargs)).response + # Add restart method to the wrapper (use respond_to_interrupt for interrupt replies). + def restart( + replace_input: Any | None = None, # noqa: ANN401 + *, + interrupt: ToolRequestPart, + resumed_metadata: dict[str, Any] | None = None, + ) -> Part: + """Create a restart request for an interrupted tool call. 
+ + Args: + replace_input: Optional new ``tool_request.input`` for this run (previous input is + stored in ``metadata.replacedInput`` when this is set). + interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). + resumed_metadata: Passed to the tool as :attr:`ToolRunContext.resumed_metadata`. + + Returns: + ``Part`` wrapping a ``ToolRequestPart`` for ``resume_restart`` / message history. + + Example: + ``pay_invoice.restart({**trp.tool_request.input, "confirmed": True}, interrupt=trp, resumed_metadata={"by": "bob"})`` + """ + tool_req = interrupt.tool_request + if tool_req.name != tool_name: + raise ValueError(f"Interrupt is for tool '{tool_req.name}', not '{tool_name}'") + + existing_meta = interrupt.metadata or {} + new_meta = dict(existing_meta) if existing_meta else {} + + # Mark as resumed + new_meta['resumed'] = resumed_metadata if resumed_metadata is not None else True + + # Store original input if replacing + new_input = tool_req.input + if replace_input is not None: + new_meta['replacedInput'] = tool_req.input + new_input = replace_input + + # Remove interrupt marker + if 'interrupt' in new_meta: + del new_meta['interrupt'] + + return Part( + root=ToolRequestPart( + tool_request=ToolRequest( + name=tool_req.name, + ref=tool_req.ref, + input=new_input, + ), + metadata=new_meta, + ) + ) + + wrapper.restart = restart # type: ignore[attr-defined] + return cast(Callable[P, T], wrapper) + + +def define_interrupt( + registry: Registry, + name: str, + *, + description: str | None = None, + request_metadata: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, + input_schema: type[BaseModel] | dict[str, object] | None = None, + output_schema: type[BaseModel] | dict[str, object] | None = None, +) -> Callable[..., Any]: + """Register a tool that always interrupts execution. + + An interrupt tool is a special tool that always raises :class:`Interrupt` with + optional metadata. 
This is useful for explicit human-in-the-loop checkpoints. + For tools that sometimes run logic and sometimes interrupt, use :func:`define_tool` + and raise :class:`Interrupt` from the handler (or use ``ToolRunContext``). + + Args: + registry: The registry to register the interrupt tool in + name: Tool name (registry key) + description: Tool description shown to the model + request_metadata: Static metadata dict or ``(input) -> dict`` for the interrupt + input_schema: Optional input schema override (Pydantic model or JSON schema dict) + output_schema: Optional output schema override (Pydantic model or JSON schema dict) + + Returns: + The registered tool callable (same shape as :func:`define_tool`). + + Example: + def get_meta(input: dict) -> dict: + return {"action": input.get("action"), "requires_approval": True} + + confirm = define_interrupt( + registry, + "confirm", + description="Requires user approval", + request_metadata=get_meta, + ) + """ + + async def interrupt_wrapper(input: Any) -> Any: # noqa: ANN401 + meta = None + if callable(request_metadata): + meta = request_metadata(input) + elif request_metadata is not None: + meta = request_metadata + raise Interrupt(meta) + + return define_tool( + registry, + interrupt_wrapper, + name=name, + description=description, + input_schema=input_schema, + output_schema=output_schema, + ) diff --git a/py/packages/genkit/src/genkit/_core/_action.py b/py/packages/genkit/src/genkit/_core/_action.py index 4b985f90a1..13c361af3f 100644 --- a/py/packages/genkit/src/genkit/_core/_action.py +++ b/py/packages/genkit/src/genkit/_core/_action.py @@ -32,6 +32,7 @@ from typing_extensions import Never, TypeVar from genkit._core._channel import Channel +from genkit._core._schema import to_json_schema from genkit._core._compat import StrEnum from genkit._core._error import GenkitError from genkit._core._trace._path import build_path @@ -323,6 +324,9 @@ def __init__( description: str | None = None, metadata: dict[str, object] | None = 
None, span_metadata: dict[str, SpanAttributeValue] | None = None, + *, + input_schema: type[BaseModel] | dict[str, object] | None = None, + output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> None: self._kind: ActionKind = kind self._name: str = name @@ -344,6 +348,30 @@ def __init__( n_action_args = len(action_args) self._fn = _make_tracing_wrapper(name, kind, span_metadata or {}, n_action_args, fn) self._initialize_io_schemas(action_args, arg_types, resolved_annotations, input_spec) + if input_schema is not None or output_schema is not None: + self._apply_tool_schema_overrides(input_schema=input_schema, output_schema=output_schema) + + def _apply_tool_schema_overrides( + self, + *, + input_schema: type[BaseModel] | dict[str, object] | None, + output_schema: type[BaseModel] | dict[str, object] | None, + ) -> None: + """Replace I/O JSON schemas (and input validation) when explicitly provided. + + Used for tools whose ``metadata_fn`` uses loose annotations (e.g. ``Any``) but + callers supply concrete Pydantic models or JSON Schema dicts. 
+ """ + if input_schema is not None: + in_js = to_json_schema(input_schema) + self.input_schema = in_js + if isinstance(input_schema, dict): + self._input_type = None + else: + self._input_type = cast(TypeAdapter[InputT], TypeAdapter(input_schema)) + + if output_schema is not None: + self.output_schema = to_json_schema(output_schema) @property def kind(self) -> ActionKind: diff --git a/py/packages/genkit/src/genkit/_core/_registry.py b/py/packages/genkit/src/genkit/_core/_registry.py index b4731900be..1e8a9c2330 100644 --- a/py/packages/genkit/src/genkit/_core/_registry.py +++ b/py/packages/genkit/src/genkit/_core/_registry.py @@ -137,6 +137,9 @@ def register_action( description: str | None = None, metadata: dict[str, object] | None = None, span_metadata: dict[str, SpanAttributeValue] | None = None, + *, + input_schema: type[BaseModel] | dict[str, object] | None = None, + output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> Action[InputT, OutputT, ChunkT]: """Register a new action with the registry. @@ -152,6 +155,9 @@ def register_action( description: Optional human-readable description of the action. metadata: Optional dictionary of metadata about the action. span_metadata: Optional dictionary of tracing span metadata. + input_schema: Optional explicit input JSON schema (Pydantic model or dict). + When set, overrides schemas inferred from ``metadata_fn``. + output_schema: Optional explicit output JSON schema (Pydantic model or dict). Returns: The newly created and registered Action instance. 
@@ -164,6 +170,8 @@ def register_action( description=description, metadata=metadata, span_metadata=span_metadata, + input_schema=input_schema, + output_schema=output_schema, ) action_typed = cast(Action[InputT, OutputT, ChunkT], action) with self._lock: diff --git a/py/packages/genkit/tests/genkit/ai/prompt_test.py b/py/packages/genkit/tests/genkit/ai/prompt_test.py index a3ecd5e49f..d095ef0a21 100644 --- a/py/packages/genkit/tests/genkit/ai/prompt_test.py +++ b/py/packages/genkit/tests/genkit/ai/prompt_test.py @@ -86,12 +86,12 @@ async def test_simple_prompt_with_override_config() -> None: my_prompt = ai.define_prompt(prompt='hi', config={'banana': True}) # New API: pass config via opts parameter - this MERGES with prompt config - response = await my_prompt(opts={'config': {'temperature': 12}}) + response = await my_prompt(config={'temperature': 12}) assert response.text == want_txt # New API: stream also uses opts - result = my_prompt.stream(opts={'config': {'temperature': 12}}) + result = my_prompt.stream(config={'temperature': 12}) assert (await result.response).text == want_txt @@ -241,10 +241,29 @@ async def test_prompt_rendering_dotprompt( """Test prompt rendering.""" ai, *_ = setup_test() - my_prompt = ai.define_prompt(**prompt) + my_prompt = ai.define_prompt( + model=prompt.get('model'), + config=prompt.get('config'), + description=prompt.get('description'), + input_schema=prompt.get('input_schema'), + system=prompt.get('system'), + prompt=prompt.get('prompt'), + messages=prompt.get('messages'), + output_format=prompt.get('output_format'), + output_content_type=prompt.get('output_content_type'), + output_instructions=prompt.get('output_instructions'), + output_constrained=prompt.get('output_constrained'), + max_turns=prompt.get('max_turns'), + return_tool_requests=prompt.get('return_tool_requests'), + metadata=prompt.get('metadata'), + tools=prompt.get('tools'), + tool_choice=prompt.get('tool_choice'), + use=prompt.get('use'), + docs=prompt.get('docs'), + 
) # New API: use opts parameter to pass config and context - response = await my_prompt(input, opts={'config': input_option, 'context': context}) + response = await my_prompt(input, config=input_option, context=context) assert response.text == want_rendered @@ -488,7 +507,7 @@ async def test_config_merge_priority() -> None: # New API: runtime config is MERGED with prompt config # - temperature: 0.9 (from opts, overrides 0.5) # - banana: 'yellow' (from prompt, preserved) - rendered = await my_prompt.render(opts={'config': {'temperature': 0.9}}) + rendered = await my_prompt.render(config={'temperature': 0.9}) assert rendered.config is not None # Config is now a dict after merging @@ -510,7 +529,7 @@ async def test_opts_can_override_model() -> None: ) # Override model via opts - response = await my_prompt(opts={'model': 'programmableModel'}) + response = await my_prompt(model='programmableModel') # Should use programmableModel, not echoModel assert response.text == 'pm response' @@ -532,7 +551,7 @@ async def test_opts_can_append_messages() -> None: ] # Append conversation history via opts - rendered = await my_prompt.render(opts={'messages': history_messages}) + rendered = await my_prompt.render(messages=history_messages) # Should have: system + history (2) + user prompt = 4 messages assert len(rendered.messages) == 4 @@ -583,13 +602,11 @@ class OutputSchema(BaseModel): output_format='text', # Default to text ) - # Override output via opts + # Override output via kwargs rendered = await my_prompt.render( - opts={ - 'output': { - 'format': 'json', - 'schema': OutputSchema, - } + output={ + 'format': 'json', + 'schema': OutputSchema, } ) @@ -599,6 +616,42 @@ class OutputSchema(BaseModel): assert rendered.output.json_schema is not None +@pytest.mark.asyncio +async def test_executable_prompt_opts_removed() -> None: + """opts= has been removed; pass options as explicit kwargs (e.g. 
model=).""" + ai, *_ = setup_test() + + my_prompt = ai.define_prompt(prompt='hi', output_format='text') + + with pytest.raises(TypeError, match='opts'): + await my_prompt(opts={'model': 'echoModel'}) + + +@pytest.mark.asyncio +async def test_executable_prompt_input_positional_opts_as_kwargs() -> None: + """ExecutablePrompt: input is positional, opts via kwargs after *.""" + ai, *_ = setup_test() + + my_prompt = ai.define_prompt( + prompt='Recipe for {{cuisine}} {{dish}}', + output_format='text', + ) + + # input = positional (template vars), output = kwarg (opts) + rendered = await my_prompt.render( + {'cuisine': 'Italian', 'dish': 'pasta'}, + output={'format': 'text'}, + ) + + # Template vars from input should be in the rendered prompt + assert any('Italian' in str(m) for m in rendered.messages) + assert any('pasta' in str(m) for m in rendered.messages) + + # output kwarg should be respected + assert rendered.output is not None + assert rendered.output.format == 'text' + + # Tests for file-based prompt loading and two-action structure @pytest.mark.asyncio async def test_file_based_prompt_registers_two_actions() -> None: diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index 0310c779f2..48967a4ed3 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -15,11 +15,11 @@ from genkit import ( Document, Genkit, + Interrupt, Message, ModelResponse, ModelResponseChunk, - ToolRunContext, - tool_response, + respond_to_interrupt, ) from genkit._ai._formats._types import FormatDef, Formatter, FormatterConfig from genkit._ai._model import text_from_message @@ -316,6 +316,32 @@ async def test_generate_with_system_prompt_messages( assert (await stream_result.response).text == want_txt +@pytest.mark.asyncio +async def test_generate_with_prompt_as_tool(setup_test: SetupFixture) -> None: + """Test that 
ai.generate(tools=[prompt.as_tool()]) works (ACC-557). + + Previously, tools param only accepted list[str]; prompt.as_tool() returns Action, + causing Pydantic ValidationError. This test verifies Action is now accepted. + """ + ai, echo, *_ = setup_test + + sub_prompt = ai.define_prompt(name='subPrompt', prompt='Sub prompt') + prompt_action = await sub_prompt.as_tool() + + # Should NOT raise ValidationError - Action in tools is now accepted + response = await ai.generate( + model='echoModel', + prompt='Use the sub prompt', + tools=[prompt_action], + tool_choice=ToolChoice.REQUIRED, + ) + + assert response.text is not None + assert echo.last_request is not None + assert len(echo.last_request.tools) == 1 + assert echo.last_request.tools[0].name == 'subPrompt' + + @pytest.mark.asyncio async def test_generate_with_tools(setup_test: SetupFixture) -> None: """Test that the generate function with tools works.""" @@ -390,9 +416,9 @@ async def test_tool(input: ToolInput) -> int: return (input.value or 0) + 7 @ai.tool(name='test_interrupt') - async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: + async def test_interrupt(input: ToolInput) -> None: """The interrupt.""" - ctx.interrupt({'banana': 'yes please'}) + raise Interrupt({'banana': 'yes please'}) tool_request_msg = Message( Message( @@ -470,7 +496,7 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: Part(root=TextPart(text='call these tools')), Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input={'value': 5}), + tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), metadata={'interrupt': {'banana': 'yes please'}}, ) ), @@ -503,9 +529,9 @@ async def test_tool(input: ToolInput) -> int: return (input.value or 0) + 7 @ai.tool(name='test_interrupt') - async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: + async def test_interrupt(input: ToolInput) -> None: """The interrupt.""" - 
ctx.interrupt({'banana': 'yes please'}) + raise Interrupt({'banana': 'yes please'}) tool_request_msg = Message( Message( @@ -542,7 +568,7 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: assert interrupted_response.tool_requests == [ Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input={'value': 5}), + tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), metadata={'interrupt': {'banana': 'yes please'}}, ), ).root, @@ -565,7 +591,7 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: Part(root=TextPart(text='call these tools')), Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input={'value': 5}), + tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), metadata={'interrupt': {'banana': 'yes please'}}, ) ), @@ -579,10 +605,12 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: ), ] + respond_wrapped = respond_to_interrupt({'bar': 2}, interrupt=interrupted_response.interrupts[0]) + assert isinstance(respond_wrapped, ToolResponsePart) response = await ai.generate( model='programmableModel', messages=interrupted_response.messages, - tool_responses=[tool_response(interrupted_response.interrupts[0], {'bar': 2})], + resume_respond=[respond_wrapped], tools=['test_tool', 'test_interrupt'], ) @@ -599,7 +627,7 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: Part(root=TextPart(text='call these tools')), Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input={'value': 5}), + tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), metadata={'resolvedInterrupt': {'banana': 'yes please'}}, ) ), diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 
1fbda9417e..4b3966b7da 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -1082,6 +1082,15 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: # Actually Google GenAI expects type=OBJECT for params usually. if not params: params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={}) + # ACC-560: Gemini rejects scalar/array root schemas. Tool params must be OBJECT with + # properties. LLMs always send {"key": value} — wrap scalar/array in {"value": }. + elif self._is_scalar_or_array_root(params): + params = genai_types.Schema( + type=genai_types.Type.OBJECT, + properties={'value': params}, + required=['value'], + description=params.description, + ) function = genai_types.FunctionDeclaration( name=tool.name, @@ -1091,6 +1100,26 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: ) return genai_types.Tool(function_declarations=[function]) + @staticmethod + def _is_scalar_or_array_root(schema: genai_types.Schema | None) -> bool: + """True if schema root is scalar or array (Gemini rejects these for tool params).""" + if not schema or not schema.type: + return False + t = schema.type + scalar_or_array = { + genai_types.Type.STRING, + genai_types.Type.NUMBER, + genai_types.Type.INTEGER, + genai_types.Type.BOOLEAN, + genai_types.Type.ARRAY, + } + if t not in scalar_or_array: + return False + # Object with properties is fine (already has properties) + if t == genai_types.Type.OBJECT: + return False + return True + def _convert_schema_property( self, input_schema: dict[str, object] | None, defs: dict[str, object] | None = None ) -> genai_types.Schema | None: @@ -1528,6 +1557,7 @@ async def _streaming_generate( cause=e, ) from e accumulated_content = [] + finish_reason: FinishReason | None = None async for response_chunk in generator: content = await self._contents_from_response(response_chunk) if content: # Only 
process if we have content @@ -1538,12 +1568,29 @@ async def _streaming_generate( role=Role.MODEL, ) ) + # Track finish_reason from last chunk (stream typically sends it in final chunk) + if response_chunk.candidates: + for c in response_chunk.candidates: + if c.finish_reason: + fr_name = c.finish_reason.name + if fr_name == 'STOP': + finish_reason = FinishReason.STOP + elif fr_name in ('MAX_TOKENS', 'MAX_OUTPUT_TOKENS'): + finish_reason = FinishReason.LENGTH + elif fr_name in ('SAFETY', 'RECITATION', 'BLOCKLIST', 'PROHIBITED_CONTENT', 'SPII'): + finish_reason = FinishReason.BLOCKED + elif fr_name == 'OTHER': + finish_reason = FinishReason.OTHER + else: + finish_reason = FinishReason.OTHER + break return ModelResponse( message=Message( role=Role.MODEL, content=accumulated_content, - ) + ), + finish_reason=finish_reason, ) @cached_property diff --git a/py/plugins/google-genai/test/models/googlegenai_gemini_test.py b/py/plugins/google-genai/test/models/googlegenai_gemini_test.py index 41e93ac0aa..5bf5602e59 100644 --- a/py/plugins/google-genai/test/models/googlegenai_gemini_test.py +++ b/py/plugins/google-genai/test/models/googlegenai_gemini_test.py @@ -481,6 +481,39 @@ def test_gemini_model__create_tool( assert isinstance(gemini_tool, genai_types.Tool) +def test_gemini_model__create_tool_wraps_scalar_input_schema( + gemini_model_instance: GeminiModel, +) -> None: + """ACC-560: Scalar/array root input schemas are wrapped in object for Gemini. + + Gemini rejects tool params with type=STRING etc. LLMs always send + {"key": value} — we wrap scalar schemas in {"value": }. 
+ """ + scalar_string_schema = genai_types.Schema( + type=genai_types.Type.STRING, + description='Echo input', + ) + tool_defined = ToolDefinition( + name='echo', + description='Echo the input string', + input_schema={'type': 'string'}, + output_schema=None, + ) + with patch.object( + gemini_model_instance, + '_convert_schema_property', + return_value=scalar_string_schema, + ): + gemini_tool = gemini_model_instance._create_tool(tool_defined) + + decl = gemini_tool.function_declarations[0] + params = decl.parameters + assert params.type == genai_types.Type.OBJECT + assert 'value' in params.properties + assert params.properties['value'].type == genai_types.Type.STRING + assert params.required == ['value'] + + @pytest.mark.parametrize( 'input_schema, defs, expected_schema', [ diff --git a/py/pyproject.toml b/py/pyproject.toml index 756acdb349..411bf53dd1 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -271,7 +271,7 @@ select = [ # Samples are demo code and can use blocking I/O for simplicity "samples/*/*.py" = ["ANN", "D", "E402", "ASYNC"] "samples/*/**/*.py" = ["ANN", "D", "E402", "ASYNC"] -"samples/*/src/*.py" = ["ANN", "D", "E402", "ASYNC"] +"samples/*/src/*.py" = ["ANN", "D", "E402", "ASYNC", "T201"] [tool.ruff.lint.isort] diff --git a/py/samples/README.md b/py/samples/README.md index 95dabb187e..1548ebc99e 100644 --- a/py/samples/README.md +++ b/py/samples/README.md @@ -35,6 +35,6 @@ Dev UI: http://localhost:4000. Most samples need `GEMINI_API_KEY`. 
See [plugins/ | `middleware` | Observe or modify model requests | | `output-formats` | Text, enum, JSON, array, and JSONL outputs | | `prompts` | `.prompt` files, variants, helpers, and streaming | -| `tool-interrupts` | Pause a tool for human approval | +| `tool-interrupts` | Trivia (`respond_example.py`) and bank approval (`approval_example.py`) — interrupt + resume | | `tracing` | Watch spans appear in real time | | `vertexai-imagen` | Generate an image with Vertex AI Imagen | diff --git a/py/samples/tool-interrupts/README.md b/py/samples/tool-interrupts/README.md index f048c2aaca..ca1c5825dc 100644 --- a/py/samples/tool-interrupts/README.md +++ b/py/samples/tool-interrupts/README.md @@ -1,19 +1,33 @@ -# Tool Interrupts +# Tool interrupts -Human-in-the-loop: `ctx.interrupt()` and `tool_response()` let the model pause for user input, then continue. +Usually the generate loop calls a tool, your code runs, returns a value, and then restarts the generate loop again until it terminates on its own or hits a stopping condition. + +With an **interrupt**, the tool **doesn’t** finish that way: a tool can `raise Interrupt(...)` and **hand control back to your application**. Think of it as the tool saying “you handle this step”—collect input, call another service, enforce policy—**instead of** returning a final tool result in one shot. + +The Genkit SDK **stops that generation turn**, surfaces the pending tool call (with your payload on `metadata["interrupt"]`), and you call `generate` **again** later with the **same `messages`** plus `resume_respond` or `resume_restart`. Either you **inject the tool outcome** (respond) or you **ask the SDK to run the tool again** with new input and metadata. + +## Samples + +`respond_example.py` — Trivia: the “tool” hands off to the CLI; your answer **is** the tool result (`respond_to_interrupt` + `resume_respond`). Prompt: `prompts/trivia_host_cli.prompt`. 
+ +`approval_example.py` — Bank demo: `y` restarts the tool (`resume_restart`); `n` declines with respond (`resume_respond`). `USER_MESSAGE` is hardcoded; you only type y/n. Prompt: `prompts/bank_transfer_host_cli.prompt`. + +## Run + +`GEMINI_API_KEY` (Google AI plugin): ```bash export GEMINI_API_KEY=your-api-key uv sync -uv run src/main.py +uv run src/respond_example.py +uv run src/approval_example.py ``` -This launches a small interactive CLI trivia session. - -To inspect the same flow in Dev UI instead: +From repo root: ```bash -genkit start -- uv run src/main.py +uv run --directory py/samples/tool-interrupts python src/respond_example.py +uv run --directory py/samples/tool-interrupts python src/approval_example.py ``` -Try `play_trivia`. +Wire detail: `MESSAGE_SHAPES.md`. \ No newline at end of file diff --git a/py/samples/tool-interrupts/prompts/bank_transfer_host_cli.prompt b/py/samples/tool-interrupts/prompts/bank_transfer_host_cli.prompt new file mode 100644 index 0000000000..a197447ea8 --- /dev/null +++ b/py/samples/tool-interrupts/prompts/bank_transfer_host_cli.prompt @@ -0,0 +1,14 @@ +--- +model: googleai/gemini-3-flash-preview +--- + +You are a **bank assistant** in a demo CLI. Help users check balances in plain language, but **any outgoing transfer of money** must go through the approval tool first. + +When the user asks you to **send, wire, or transfer** funds to a person or account, call `request_transfer` with: +- `to_account`: who gets the money (name or masked account id) +- `amount_usd`: amount (e.g. `250.00`) +- `memo`: short reason (e.g. `rent`, `invoice #12`) + +Do **not** pretend the transfer already happened until the user has approved it in the CLI. Keep replies short. 
+ +[user joined online banking] diff --git a/py/samples/tool-interrupts/prompts/trivia_host_cli.prompt b/py/samples/tool-interrupts/prompts/trivia_host_cli.prompt new file mode 100644 index 0000000000..c613dd049a --- /dev/null +++ b/py/samples/tool-interrupts/prompts/trivia_host_cli.prompt @@ -0,0 +1,9 @@ +--- +model: googleai/gemini-3-flash-preview +--- + +You are a trivia game host. Cheerfully greet the user when they first join and ask them for the theme of the trivia game. Suggest a few theme options, but they do not have to use them. + +When the user is ready for a question, call `present_questions` so the UI can show the question and multiple-choice answers. After the user answers, tell them if they were right or wrong. Be dramatic but brief. + +[user joined the game] diff --git a/py/samples/tool-interrupts/src/approval_example.py b/py/samples/tool-interrupts/src/approval_example.py new file mode 100644 index 0000000000..d1ecdc6a62 --- /dev/null +++ b/py/samples/tool-interrupts/src/approval_example.py @@ -0,0 +1,194 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""**Bank transfer approval** — human-in-the-loop before a transfer tool finishes. + +The model calls ``request_transfer``; the CLI asks **approve (y)** or **decline (n)**. +**Approve** → ``restart_interrupted_tool`` / ``resume_restart`` so the tool **runs again** +with :class:`~genkit.ToolRunContext.is_resumed`. 
**Decline** → ``respond_to_interrupt`` / +``resume_respond`` (no second tool run). + +Opening prompt, then one **canned user message** (no typing) so the model calls +``request_transfer``; you still answer **y/n** for approval. Run:: + + uv run src/approval_example.py + +For the trivia-only **respond** demo, see ``respond_example.py``. See README.md. +""" + +from pathlib import Path + +from pydantic import BaseModel, Field + +from genkit import ( + Genkit, + Interrupt, + ToolRunContext, + respond_to_interrupt, + restart_interrupted_tool, +) +from genkit.model import ModelResponse +from genkit.plugins.google_genai import GoogleAI # pyright: ignore[reportMissingImports] + +_PROMPTS_DIR = Path(__file__).resolve().parent.parent / 'prompts' + +_BAR = '=' * 52 +_RULE = '-' * 52 + +# Canned user line so the demo always triggers a transfer tool call without stdin. +USER_MESSAGE = ( + 'Please wire $250.00 to Jane Doe (account ending in 4521) for April rent.' +) + + +def _print_intro() -> None: + print(f'\n{_BAR}') + print(' Bank transfer demo — outgoing wires need your approval in the CLI.') + print(_BAR) + print(' 1) The banker speaks first.') + print(' 2) A scripted user message asks for a wire; the model calls the transfer tool.') + print(' 3) When asked y/n: yes = approve (tool runs again); no = decline.') + print(f'{_BAR}\n') + + +def _print_scripted_user_turn() -> None: + print(_RULE) + print('Scripted user message (see USER_MESSAGE in source):') + print(_RULE) + + +def _print_waiting_opening() -> None: + print('Starting: banker opening (please wait)...\n') + + +def _print_model_turn(label: str, r: ModelResponse) -> None: + print(f'\n[{label}]') + if r.text: + print(r.text) + + +def _print_transfer_approval_prompt(summary: str) -> None: + print('\n' + _BAR) + print(' TRANSFER APPROVAL — y = approve (rerun tool) | n = decline') + print(_BAR) + if summary: + print(f' {summary}') + + +def _print_unexpected_tool(name: str) -> None: + print(f'Unexpected tool: {name!r}') + 
+ +ai = Genkit( + plugins=[GoogleAI()], + model='googleai/gemini-3-flash-preview', + prompt_dir=_PROMPTS_DIR, +) + + +class TransferRequest(BaseModel): + """Wire transfer the user asked for; shown again before approval.""" + + to_account: str = Field(description='recipient name or masked account identifier') + amount_usd: str = Field(description='amount as a string, e.g. 250.00') + memo: str = Field(default='', description='short reason (rent, invoice, gift, …)') + + +@ai.tool() +async def request_transfer(body: TransferRequest, ctx: ToolRunContext) -> dict: + """First run: interrupt for approval. After approval: return confirmation with metadata.""" + if not ctx.is_resumed(): + line = f"Wire ${body.amount_usd} to {body.to_account}" + if body.memo: + line = f'{line} — {body.memo}' + raise Interrupt( + { + 'summary': line, + 'to_account': body.to_account, + 'amount_usd': body.amount_usd, + 'memo': body.memo, + 'needs_approval': True, + } + ) + return {'status': 'confirmed', 'resumed': ctx.resumed_metadata} + + +async def interactive_restart_cli() -> None: + """Opening prompt, scripted user line, then transfer approval via ``request_transfer``.""" + + _print_intro() + + _print_waiting_opening() + response = await ai.prompt('bank_transfer_host_cli')() + messages = response.messages + _print_model_turn('Banker (opening)', response) + + _print_scripted_user_turn() + user_said = USER_MESSAGE + print(f'\nYou: {user_said}\n') + + response = await ai.generate( + messages=messages, + prompt=user_said, + tools=[request_transfer], + ) + messages = response.messages + _print_model_turn('Banker', response) + + while response.interrupts: + interrupt = response.interrupts[0] + name = interrupt.tool_request.name + if name != request_transfer.__name__: + _print_unexpected_tool(name) + return + + meta = interrupt.metadata.get('interrupt') if interrupt.metadata else True + summary = meta.get('summary', '') if isinstance(meta, dict) else '' + _print_transfer_approval_prompt(summary) + 
ans = input('Approve transfer? [y/N]: ').strip().lower() + + if ans in ('y', 'yes'): + restart = restart_interrupted_tool( + interrupt=interrupt, + tool=request_transfer, + resumed_metadata={'via': 'cli', 'path': 'restart'}, + ) + response = await ai.generate( + messages=messages, + resume_restart=restart, + tools=[request_transfer], + ) + else: + decline_response = respond_to_interrupt( + {'status': 'declined'}, + interrupt=interrupt, + metadata={'source': 'cli', 'path': 'respond_decline'}, + ) + response = await ai.generate( + messages=messages, + resume_respond=decline_response, + tools=[request_transfer], + ) + messages = response.messages + _print_model_turn('Banker (after your decision)', response) + + +async def main() -> None: + await interactive_restart_cli() + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/samples/tool-interrupts/src/main.py b/py/samples/tool-interrupts/src/main.py deleted file mode 100755 index 4ce87963eb..0000000000 --- a/py/samples/tool-interrupts/src/main.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 - -"""Tool interrupts - Human-in-the-loop with ctx.interrupt() and tool_response(). 
See README.md.""" - -import os - -from pydantic import BaseModel, Field - -from genkit import ( - Genkit, - ToolRunContext, - tool_response, -) -from genkit.plugins.google_genai import GoogleAI -from genkit.plugins.google_genai.models import gemini - -ai = Genkit( - plugins=[GoogleAI()], - model=f'googleai/{gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW}', -) - - -class TriviaQuestions(BaseModel): - """Trivia questions.""" - - question: str = Field(description='the main question') - answers: list[str] = Field(description='list of multiple choice answers (typically 4), 1 correct 3 wrong') - - -@ai.tool() -async def present_questions(questions: TriviaQuestions, ctx: ToolRunContext) -> None: - """Presents questions to the user and responds with the selected answer.""" - ctx.interrupt(questions.model_dump()) - - -@ai.flow() -async def play_trivia(theme: str = 'Science') -> str: - """Plays a trivia game with the user.""" - response = await ai.generate( - prompt='You are a trivia game host. Cheerfully greet the user when they ' - + f'first join. The user has selected the theme: "{theme}". ' - + 'Call `present_questions` tool with questions and the tools will present ' - + 'the questions in a nice UI. The user will pick an answer and then you ' - + 'tell them if they were right or wrong. Be dramatic (but terse)! It is a ' - + 'show!\n\n[user joined the game]', - tools=['present_questions'], - ) - - # Check for interrupts and return the question to the user - if len(response.interrupts): - request = response.interrupts[0] - question_data = request.tool_request.input - if question_data: - # For a full interactive flow, you would typically: - # 1. Prompt the user for their answer here (e.g., using input()). - # 2. Call tool_response(request, user_answer) to resume the AI conversation. - # 3. Regenerate with the tool_response. 
- - # Prepend the greeting/text response if available - text_response = (response.text + '\n\n') if response.text else '' - question = question_data.get('question') - answers = question_data.get('answers') - return f'{text_response}INTERRUPTED: {question}\nAnswers: {answers}' - return 'INTERRUPTED: (No input data)' - - return response.text - - -async def main() -> None: - """Dev mode: return immediately; run_main keeps Dev UI alive. Standalone: interactive trivia CLI.""" - if os.environ.get('GENKIT_ENV') == 'dev': - return - try: - response = await ai.generate( - prompt='You are a trivia game host. Cheerfully greet the user when they ' - + 'first join and ask them for the theme of the trivia game. Suggest ' - + 'a few theme options, but they do not have to use them. When the user is ready, call ' - + '`present_questions` so the UI can show the question and answers. ' - + 'After the user answers, tell them if they were right or wrong. Be dramatic but brief.\n\n' - + '[user joined the game]', - ) - messages = response.messages - while True: - response = await ai.generate( - messages=messages, - prompt=input('Say: '), - tools=['present_questions'], - ) - messages = response.messages - if len(response.interrupts) > 0: - request = response.interrupts[0] - tr = tool_response(request, input('Your answer (number): ')) - response = await ai.generate( - messages=messages, - tool_responses=[tr], - tools=['present_questions'], - ) - messages = response.messages - except Exception as error: - print(f'Set GEMINI_API_KEY to a valid value before running this sample directly.\n{error}') # noqa: T201 - - -if __name__ == '__main__': - ai.run_main(main()) diff --git a/py/samples/tool-interrupts/src/respond_example.py b/py/samples/tool-interrupts/src/respond_example.py new file mode 100755 index 0000000000..83cecda494 --- /dev/null +++ b/py/samples/tool-interrupts/src/respond_example.py @@ -0,0 +1,163 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tool interrupts — trivia via ``present_questions`` and ``respond_to_interrupt``. + +``present_questions`` raises :class:`~genkit.Interrupt` with the question payload → the user +picks an answer → ``respond_to_interrupt(pick, interrupt=…, metadata=…)`` → second +``generate`` with ``resume_respond``. + +For **bank-transfer-style** tool approval (restart path), run ``approval_example.py`` instead. + +Run: ``uv run src/respond_example.py``. See README.md. 
+""" + +from pathlib import Path + +from pydantic import BaseModel, Field + +from genkit import ( + Genkit, + Interrupt, + ToolResponsePart, + respond_to_interrupt, +) +from genkit.model import ModelResponse +from genkit.plugins.google_genai import GoogleAI # pyright: ignore[reportMissingImports] + +_PROMPTS_DIR = Path(__file__).resolve().parent.parent / 'prompts' + +ai = Genkit( + plugins=[GoogleAI()], + model='googleai/gemini-3-flash-preview', + prompt_dir=_PROMPTS_DIR, +) + + +class TriviaQuestions(BaseModel): + """Payload passed into ``present_questions`` when the model calls the tool.""" + + question: str = Field(description='the main question') + answers: list[str] = Field( + description='list of multiple choice answers (typically 4), 1 correct 3 wrong', + ) + + +@ai.tool() +async def present_questions(questions: TriviaQuestions) -> None: + """Presents questions to the user and responds with the selected answer.""" + raise Interrupt(questions.model_dump(mode='json')) + + +DEMO_TOOLS = [present_questions] + + +async def interactive_trivia_cli() -> None: + """Run the CLI: opening turn, then chat with trivia interrupts (respond path).""" + + def show(label: str, r: ModelResponse) -> None: + """Print one model turn (what the host said).""" + print(f'\n[{label}]') + if r.text: + print(r.text) + + quit_words = frozenset({'q', 'quit', 'exit', 'bye'}) + bar = '=' * 52 + + print(f'\n{bar}') + print(' Tool interrupt demo — trivia (respond path)') + print(bar) + print(' 1) Host speaks first (you do not type yet).') + print(' 2) Then you chat one line at a time.') + print(' 3) When you see numbered answers, reply with a number.') + print(' Say quit / exit / q / bye anytime to stop.') + print(f'{bar}\n') + + print('Starting: host opening (please wait)...\n') + response = await ai.prompt('trivia_host_cli')() + messages = response.messages + show('Host (opening)', response) + + print('-' * 52) + print('Your turn — reply to the host above, or type quit to leave.') + 
print('-' * 52) + + while True: + user_said = input('\nYou: ').strip() + if user_said.lower() in quit_words: + print('Goodbye.') + return + if not user_said: + print('Empty line — type a message, or quit to exit.') + continue + + response = await ai.generate( + messages=messages, + prompt=user_said, + tools=DEMO_TOOLS, + ) + messages = response.messages + show('Host', response) + + while response.interrupts: + interrupt = response.interrupts[0] + name = interrupt.tool_request.name + + if name != present_questions.__name__: + print(f'Unexpected tool: {name!r}') + return + + payload = interrupt.tool_request.input + if payload is None: + print('Interrupt with no tool input.') + return + trivia = TriviaQuestions.model_validate(payload) + + n = len(trivia.answers) + print('\n' + bar) + print(' QUESTION — answer with a number') + print(bar) + print(trivia.question) + for i, ans in enumerate(trivia.answers, start=1): + print(f' {i}. {ans}') + print(f'Enter 1–{n}.') + + pick = input('Your choice (number): ').strip() + if pick.lower() in quit_words: + print('Goodbye.') + return + + interrupt_response = respond_to_interrupt( + pick, + interrupt=interrupt, + metadata={'source': 'cli', 'path': 'respond'}, + ) + assert isinstance(interrupt_response, ToolResponsePart) + response = await ai.generate( + messages=messages, + resume_respond=[interrupt_response], + tools=DEMO_TOOLS, + ) + messages = response.messages + show('Host (after your answer)', response) + + +async def main() -> None: + await interactive_trivia_cli() + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/uv.lock b/py/uv.lock index 8bd4e7a10b..adcfe31bcb 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -1569,7 +1569,7 @@ requires-dist = [ [[package]] name = "genkit-plugin-evaluators" -version = "0.1.0" +version = "0.5.1" source = { editable = "plugins/evaluators" } dependencies = [ { name = "genkit" }, From 5538bc5642d9965629bd0569e3bcea4818e8e3b9 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 
26 Mar 2026 12:43:17 -0500 Subject: [PATCH 02/15] cleanup --- py/docs/index.md | 2 +- py/docs/types.md | 2 +- py/packages/genkit/src/genkit/__init__.py | 2 - .../genkit/src/genkit/_ai/_generate.py | 63 ++++++++++--------- py/packages/genkit/src/genkit/_ai/_tools.py | 60 +++++++----------- .../genkit/tests/genkit/veneer/veneer_test.py | 8 +-- 6 files changed, 63 insertions(+), 74 deletions(-) diff --git a/py/docs/index.md b/py/docs/index.md index 3023418f48..f7c2618977 100644 --- a/py/docs/index.md +++ b/py/docs/index.md @@ -35,7 +35,7 @@ ::: genkit.PublicError -::: genkit.ToolInterruptError +::: genkit.Interrupt ::: genkit.Message diff --git a/py/docs/types.md b/py/docs/types.md index 823b4c2259..239d1faf94 100644 --- a/py/docs/types.md +++ b/py/docs/types.md @@ -32,7 +32,7 @@ Types exported from genkit, genkit.model, genkit.embedder, genkit.plugin_api, an ::: genkit.PublicError -::: genkit.ToolInterruptError +::: genkit.Interrupt ::: genkit.Message diff --git a/py/packages/genkit/src/genkit/__init__.py b/py/packages/genkit/src/genkit/__init__.py index a5185d3d56..4546d1c63a 100644 --- a/py/packages/genkit/src/genkit/__init__.py +++ b/py/packages/genkit/src/genkit/__init__.py @@ -25,7 +25,6 @@ ) from genkit._ai._tools import ( Interrupt, - ToolInterruptError, ToolRunContext, respond_to_interrupt, restart_interrupted_tool, @@ -99,7 +98,6 @@ # Errors 'GenkitError', 'PublicError', - 'ToolInterruptError', 'Interrupt', 'respond_to_interrupt', 'restart_interrupted_tool', diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index a8839e0a9a..c307e9623a 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -35,7 +35,7 @@ ModelResponseChunk, ) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources -from genkit._ai._tools import ToolInterruptError, run_tool_after_restart +from genkit._ai._tools import 
Interrupt, run_tool_after_restart from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import GenkitError from genkit._core._logger import get_logger @@ -649,14 +649,12 @@ async def resolve_tool_requests( tool_response_part, interrupt_part = await _resolve_tool_request(tool, tool_req_root) if tool_response_part: - # Extract the ToolResponsePart from the returned Part for _to_pending_response - if isinstance(tool_response_part.root, ToolResponsePart): - revised_model_message.content[i] = _to_pending_response(tool_req_root, tool_response_part.root) - response_parts.append(tool_response_part) + revised_model_message.content[i] = _to_pending_response(tool_req_root, tool_response_part) + response_parts.append(Part(root=tool_response_part)) if interrupt_part: has_interrupts = True - revised_model_message.content[i] = interrupt_part + revised_model_message.content[i] = Part(root=interrupt_part) i += 1 @@ -679,44 +677,49 @@ def _to_pending_response(request: ToolRequestPart, response: ToolResponsePart) - ) -async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart) -> tuple[Part | None, Part | None]: - """Execute a tool and return (response_part, interrupt_part).""" +def _interrupt_from_tool_exc(exc: BaseException) -> Interrupt | None: + """If ``exc`` is (or wraps) :class:`~genkit._ai._tools.Interrupt`, return that interrupt.""" + if isinstance(exc, Interrupt): + return exc + if isinstance(exc, GenkitError) and exc.cause is not None and isinstance(exc.cause, Interrupt): + return exc.cause + return None + + +async def _resolve_tool_request( + tool: Action, tool_request_part: ToolRequestPart +) -> tuple[ToolResponsePart | None, ToolRequestPart | None]: + """Execute a tool. + + Returns ``(ToolResponsePart, None)`` on success or ``(None, ToolRequestPart)`` when interrupted. 
+ """ try: tool_response = (await tool.run(tool_request_part.tool_request.input)).response - # Part is a RootModel, so we pass content via 'root' parameter return ( - Part( - root=ToolResponsePart( - tool_response=ToolResponse( - name=tool_request_part.tool_request.name, - ref=tool_request_part.tool_request.ref, - output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, - ) + ToolResponsePart( + tool_response=ToolResponse( + name=tool_request_part.tool_request.name, + ref=tool_request_part.tool_request.ref, + output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, ) ), None, ) - except GenkitError as e: - if e.cause and isinstance(e.cause, ToolInterruptError): - interrupt_error = e.cause - # Part is a RootModel, so we pass content via 'root' parameter + except Exception as e: + intr = _interrupt_from_tool_exc(e) + if intr is not None: + payload: dict[str, Any] | bool = intr.data if intr.data else True tool_meta = tool_request_part.metadata or {} if not isinstance(tool_meta, dict): tool_meta = dict(tool_meta) return ( None, - Part( - root=ToolRequestPart( - tool_request=tool_request_part.tool_request, - metadata={ - **tool_meta, - 'interrupt': (interrupt_error.metadata if interrupt_error.metadata else True), - }, - ) + ToolRequestPart( + tool_request=tool_request_part.tool_request, + metadata={**tool_meta, 'interrupt': payload}, ), ) - - raise e + raise async def resolve_tool(registry: Registry, tool_name: str) -> Action: diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index d7bf09e3de..295d8752b0 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -62,21 +62,13 @@ def is_resumed(self) -> bool: return self.resumed_metadata is not None -# TODO(#4346): make this extend GenkitError once it has INTERRUPTED status -class ToolInterruptError(Exception): - """Controlled interruption of 
tool execution (e.g., to request user input).""" - - def __init__(self, metadata: dict[str, Any] | None = None) -> None: - """Initialize with optional interrupt metadata.""" - super().__init__() - self.metadata: dict[str, Any] = metadata or {} - - class Interrupt(Exception): """Exception for interrupting tool execution with user-facing API. - Use ``raise Interrupt(data)`` inside a tool to pause. Prefer this over raising - :class:`ToolInterruptError` directly. + Raise ``Interrupt(data)`` from a tool or from tool middleware (e.g. ``wrap_tool``). + Exceptions from ``tool.run`` are wrapped in :class:`~genkit._core._error.GenkitError` + with ``cause=Interrupt``; generation attaches interrupt metadata to the pending tool + request. To resume, use :func:`respond_to_interrupt` or :func:`restart_interrupted_tool` (or ``tool.restart`` on the registered tool callable). @@ -191,7 +183,7 @@ async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolR try: tool_response = (await tool.run(restart_trp.tool_request.input)).response except GenkitError as e: - if e.cause and isinstance(e.cause, ToolInterruptError): + if e.cause and isinstance(e.cause, Interrupt): raise GenkitError( status='FAILED_PRECONDITION', message='Tool interrupted again during a restart execution; not supported yet.', @@ -257,29 +249,25 @@ def define_tool( async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 # Dynamic dispatch based on function signature - pyright can't verify ParamSpec here - try: - match len(input_spec.args): - case 0: - return await func_any() - case 1: - return await func_any(args[0]) - case 2: - # Read from context variables for resumed metadata - resumed_meta = _tool_resumed_metadata.get() - original_input = _tool_original_input.get() - return await func_any( - args[0], - ToolRunContext( - cast(ActionRunContext, args[1]), - resumed_metadata=resumed_meta, - original_input=original_input, - ), - ) - case _: - raise ValueError('tool must have 0-2 args...') - 
except Interrupt as e: - # Convert Interrupt to ToolInterruptError for compatibility with existing flow - raise ToolInterruptError(metadata=e.data) from e + match len(input_spec.args): + case 0: + return await func_any() + case 1: + return await func_any(args[0]) + case 2: + # Read from context variables for resumed metadata + resumed_meta = _tool_resumed_metadata.get() + original_input = _tool_original_input.get() + return await func_any( + args[0], + ToolRunContext( + cast(ActionRunContext, args[1]), + resumed_metadata=resumed_meta, + original_input=original_input, + ), + ) + case _: + raise ValueError('tool must have 0-2 args...') action = registry.register_action( name=tool_name, diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index 48967a4ed3..fb7f253a72 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -496,7 +496,7 @@ async def test_interrupt(input: ToolInput) -> None: Part(root=TextPart(text='call these tools')), Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), + tool_request=ToolRequest(ref='123', name='test_interrupt', input={'value': 5}), metadata={'interrupt': {'banana': 'yes please'}}, ) ), @@ -568,7 +568,7 @@ async def test_interrupt(input: ToolInput) -> None: assert interrupted_response.tool_requests == [ Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), + tool_request=ToolRequest(ref='123', name='test_interrupt', input={'value': 5}), metadata={'interrupt': {'banana': 'yes please'}}, ), ).root, @@ -591,7 +591,7 @@ async def test_interrupt(input: ToolInput) -> None: Part(root=TextPart(text='call these tools')), Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), + tool_request=ToolRequest(ref='123', 
name='test_interrupt', input={'value': 5}), metadata={'interrupt': {'banana': 'yes please'}}, ) ), @@ -627,7 +627,7 @@ async def test_interrupt(input: ToolInput) -> None: Part(root=TextPart(text='call these tools')), Part( root=ToolRequestPart( - tool_request=ToolRequest(ref='123', name='test_interrupt', input=ToolInput(value=5)), + tool_request=ToolRequest(ref='123', name='test_interrupt', input={'value': 5}), metadata={'resolvedInterrupt': {'banana': 'yes please'}}, ) ), From 0f9ba86507d1f2de8143902495863856e3362d11 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 26 Mar 2026 15:15:04 -0500 Subject: [PATCH 03/15] fix bugs from testing tools w/ gemini with scalar input --- .../genkit/src/genkit/_ai/_generate.py | 8 +++- py/packages/genkit/src/genkit/_ai/_tools.py | 41 ++++++++++++++++++- .../plugins/google_genai/models/gemini.py | 5 ++- .../plugins/google_genai/models/utils.py | 6 ++- 4 files changed, 55 insertions(+), 5 deletions(-) diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index c307e9623a..a20abe5f51 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -35,7 +35,7 @@ ModelResponseChunk, ) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources -from genkit._ai._tools import Interrupt, run_tool_after_restart +from genkit._ai._tools import Interrupt, run_tool_after_restart, unwrap_wrapped_scalar_tool_input_if_needed from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import GenkitError from genkit._core._logger import get_logger @@ -694,7 +694,11 @@ async def _resolve_tool_request( Returns ``(ToolResponsePart, None)`` on success or ``(None, ToolRequestPart)`` when interrupted. 
""" try: - tool_response = (await tool.run(tool_request_part.tool_request.input)).response + tool_in = unwrap_wrapped_scalar_tool_input_if_needed( + tool_request_part.tool_request.input, + tool.input_schema, + ) + tool_response = (await tool.run(tool_in)).response return ( ToolResponsePart( tool_response=ToolResponse( diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index 295d8752b0..862f0d54ca 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -37,6 +37,41 @@ _tool_original_input: ContextVar[Any | None] = ContextVar('tool_original_input', default=None) +def _json_schema_root_is_scalar_or_array(schema: dict[str, object] | None) -> bool: + """Return True if the JSON Schema root is a scalar or array type.""" + if not schema: + return False + t = schema.get('type') + if isinstance(t, str): + tl = t.lower() + return tl in ('string', 'number', 'integer', 'boolean', 'array') + return False + + +def unwrap_wrapped_scalar_tool_input_if_needed( + input_payload: Any, # noqa: ANN401 + input_schema: dict[str, object] | None, +) -> Any: + """Unwrap ``{"value": x}`` before calling a tool whose JSON Schema root is scalar/array. + + The Google Genai Gemini plugin wraps scalar/array roots as ``{"value": }`` + at declaration time so the Gemini API accepts them. The model then sends arguments + in that same shape. This helper strips the wrapper so the tool handler receives the + bare value it was defined with. + + Note: ``ToolRequest.input`` is NOT modified — this unwrap happens only at call time + so that message history keeps the original wire format (required for subsequent + Gemini turns where ``FunctionCall.args`` must be a dict). 
+ """ + if not _json_schema_root_is_scalar_or_array(input_schema): + return input_payload + if not isinstance(input_payload, dict): + return input_payload + if set(input_payload.keys()) != {'value'}: + return input_payload + return input_payload['value'] + + class ToolRunContext(ActionRunContext): """Tool execution context with interrupt support.""" @@ -181,7 +216,11 @@ async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolR token_input = _tool_original_input.set(original_input) try: try: - tool_response = (await tool.run(restart_trp.tool_request.input)).response + tool_in = unwrap_wrapped_scalar_tool_input_if_needed( + restart_trp.tool_request.input, + tool.input_schema, + ) + tool_response = (await tool.run(tool_in)).response except GenkitError as e: if e.cause and isinstance(e.cause, Interrupt): raise GenkitError( diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 4b3966b7da..2a6c6a7b44 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -1363,7 +1363,10 @@ async def generate(self, request: ModelRequest, ctx: ActionRunContext) -> ModelR ) else: response = await self._generate( - request_contents=request_contents, request_cfg=request_cfg, model_name=model_name, client=client + request_contents=request_contents, + request_cfg=request_cfg, + model_name=model_name, + client=client, ) response.usage = self._create_usage_stats(request=request, response=response) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py index 08fd486006..32babfa0a0 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py @@ -278,7 
+278,11 @@ def _to_gemini_custom(cls, part: Part | DocumentPart) -> genai.types.Part: return genai.types.Part() @classmethod - def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part: + def from_gemini( + cls, + part: genai.types.Part, + ref: str | None = None, + ) -> Part: """Maps a Gemini Part back to a Genkit Part. This method inspects the type of the Gemini Part and converts it into From 57510fbfa5d2d8c219036facd8f144b586a67c32 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 26 Mar 2026 18:48:54 -0500 Subject: [PATCH 04/15] fixes --- py/packages/genkit/src/genkit/_ai/_aio.py | 3 +- py/packages/genkit/src/genkit/_ai/_prompt.py | 18 ++++----- py/packages/genkit/src/genkit/_ai/_tools.py | 38 ++++++++++--------- .../genkit/src/genkit/_core/_action.py | 2 +- .../genkit/tests/genkit/ai/prompt_test.py | 4 +- .../genkit/tests/genkit/veneer/veneer_test.py | 6 ++- .../tool-interrupts/src/approval_example.py | 22 +++++------ 7 files changed, 46 insertions(+), 47 deletions(-) diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index b8fa6076d3..f403a4fd9d 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -279,7 +279,7 @@ def define_interrupt( input_schema: type[BaseModel] | dict[str, object] | None = None, output_schema: type[BaseModel] | dict[str, object] | None = None, description: str | None = None, - ) -> Callable[P, T]: + ) -> Callable[..., Any]: """Register an interrupt tool that always pauses for user input. 
Args: @@ -299,7 +299,6 @@ def define_interrupt( description='Ask the user a question', ) """ - return define_interrupt( self.registry, name, diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 664fab6f3f..787debcd90 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -20,13 +20,11 @@ import asyncio import os import weakref -from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass from pathlib import Path from typing import Any, ClassVar, Generic, TypedDict, TypeVar, cast -from typing_extensions import Unpack - from dotpromptz.typing import ( DataArgument, PromptFunction, @@ -34,6 +32,8 @@ PromptMetadata, ) from pydantic import BaseModel, ConfigDict +from typing_extensions import Unpack + from genkit._ai._generate import ( ToolReference, generate_action, @@ -143,17 +143,17 @@ class PromptGenerateOptions(TypedDict, total=False): _PROMPT_GENERATE_OPTION_KEYS: frozenset[str] = frozenset(PromptGenerateOptions.__annotations__) -def _coerce_prompt_opts(opts: PromptGenerateOptions) -> PromptGenerateOptions: +def _coerce_prompt_opts(opts: Mapping[str, Any]) -> PromptGenerateOptions: """Build effective opts from a kwargs mapping: drop keys whose value is None. Rejects unknown keys at runtime (Unpack only enforces this for type checkers). """ - raw = cast(dict[str, Any], opts) + raw = dict(opts) unknown = set(raw) - _PROMPT_GENERATE_OPTION_KEYS if unknown: if 'opts' in unknown: raise TypeError( - "Passing a combined `opts` dict is not supported; use keyword arguments " + 'Passing a combined `opts` dict is not supported; use keyword arguments ' '(e.g. model=..., config=...) matching PromptGenerateOptions.' 
) sorted_unknown = ', '.join(sorted(unknown)) @@ -354,7 +354,7 @@ async def __call__( Args: input: Template variables for rendering. """ - effective_opts = _coerce_prompt_opts(cast(PromptGenerateOptions, opts)) + effective_opts = _coerce_prompt_opts(opts) return await self._call_impl(input, effective_opts) async def _call_impl( @@ -523,7 +523,7 @@ def stream( **opts: Unpack[PromptGenerateOptions], ) -> ModelStreamResponse[OutputT]: """Stream the prompt execution, returning (stream, response_future).""" - effective_opts = _coerce_prompt_opts(cast(PromptGenerateOptions, opts)) + effective_opts = _coerce_prompt_opts(opts) channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout) stream_opts: PromptGenerateOptions = { **effective_opts, @@ -545,7 +545,7 @@ async def render( Same keyword options as :meth:`__call__` (see :class:`PromptGenerateOptions`). """ await self._ensure_resolved() - coerced = _coerce_prompt_opts(cast(PromptGenerateOptions, opts)) + coerced = _coerce_prompt_opts(opts) return await self._render_impl(input, coerced) async def as_tool(self) -> Action: diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index 862f0d54ca..3f295f2eba 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -34,7 +34,8 @@ # Context variables for propagating resumed metadata to tools _tool_resumed_metadata: ContextVar[dict[str, Any] | None] = ContextVar('tool_resumed_metadata', default=None) -_tool_original_input: ContextVar[Any | None] = ContextVar('tool_original_input', default=None) +# Stashed copy of tool_request.input when restart replaces input (JSON; shape is per tool). 
+_tool_original_input: ContextVar[Any | None] = ContextVar('tool_original_input', default=None) # noqa: ANN401 def _json_schema_root_is_scalar_or_array(schema: dict[str, object] | None) -> bool: @@ -49,9 +50,9 @@ def _json_schema_root_is_scalar_or_array(schema: dict[str, object] | None) -> bo def unwrap_wrapped_scalar_tool_input_if_needed( - input_payload: Any, # noqa: ANN401 + input_payload: Any, # noqa: ANN401 - wire JSON from model; schema varies per tool input_schema: dict[str, object] | None, -) -> Any: +) -> Any: # noqa: ANN401 - same payload after optional {"value": x} unwrap """Unwrap ``{"value": x}`` before calling a tool whose JSON Schema root is scalar/array. The Google Genai Gemini plugin wraps scalar/array roots as ``{"value": }`` @@ -79,7 +80,7 @@ def __init__( self, ctx: ActionRunContext, resumed_metadata: dict[str, Any] | None = None, - original_input: Any = None, + original_input: Any = None, # noqa: ANN401 - prior tool_request.input when replacing on restart ) -> None: """Initialize from parent ActionRunContext. @@ -97,7 +98,7 @@ def is_resumed(self) -> bool: return self.resumed_metadata is not None -class Interrupt(Exception): +class Interrupt(Exception): # noqa: N818 - public Genkit name (JS/Go); not renamed *Error for style """Exception for interrupting tool execution with user-facing API. Raise ``Interrupt(data)`` from a tool or from tool middleware (e.g. ``wrap_tool``). 
@@ -122,7 +123,7 @@ def __init__(self, data: dict[str, Any] | None = None) -> None: def _tool_response_part( interrupt: ToolRequestPart, - output: Any, + output: Any, # noqa: ANN401 - arbitrary tool/interrupt reply payload (JSON) metadata: dict[str, Any] | None = None, ) -> ToolResponsePart: """Build a ``ToolResponsePart`` for an interrupted tool request (interrupt reply channel).""" @@ -139,7 +140,7 @@ def _tool_response_part( def respond_to_interrupt( - response: Any, + response: Any, # noqa: ANN401 - user reply or tool output for resume_respond *, interrupt: ToolRequestPart, metadata: dict[str, Any] | None = None, @@ -157,7 +158,7 @@ def respond_to_interrupt( def restart_interrupted_tool( - replace_input: Any | None = None, # noqa: ANN401 + replace_input: Any | None = None, # noqa: ANN401 - new tool_request.input JSON or None to keep prior *, interrupt: ToolRequestPart, tool: Callable[..., Any], @@ -181,9 +182,7 @@ def restart_interrupted_tool( """ restart = getattr(tool, 'restart', None) if restart is None: - raise TypeError( - f'{tool!r} has no restart method; pass a tool from define_tool or Genkit.tool' - ) + raise TypeError(f'{tool!r} has no restart method; pass a tool from define_tool or Genkit.tool') part = restart( replace_input, interrupt=interrupt, @@ -286,8 +285,8 @@ def define_tool( func_any = cast(Callable[..., Any], func) - async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - # Dynamic dispatch based on function signature - pyright can't verify ParamSpec here + async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - arity dispatch; args/return follow registered tool + # Dynamic dispatch by arity; payload types follow the registered tool (not expressible here). 
match len(input_spec.args): case 0: return await func_any() @@ -320,12 +319,13 @@ async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: # noqa: ANN401 + # Return type is the tool's declared output; ParamSpec preserves call shape only. action_any = cast(Any, action) return (await action_any.run(*args, **kwargs)).response # Add restart method to the wrapper (use respond_to_interrupt for interrupt replies). def restart( - replace_input: Any | None = None, # noqa: ANN401 + replace_input: Any | None = None, # noqa: ANN401 - same as restart_interrupted_tool.replace_input *, interrupt: ToolRequestPart, resumed_metadata: dict[str, Any] | None = None, @@ -342,14 +342,15 @@ def restart( ``Part`` wrapping a ``ToolRequestPart`` for ``resume_restart`` / message history. Example: - ``pay_invoice.restart({**trp.tool_request.input, "confirmed": True}, interrupt=trp, resumed_metadata={"by": "bob"})`` + ``pay_invoice.restart({**trp.tool_request.input, "confirmed": True}, interrupt=trp,`` + ``resumed_metadata={"by": "bob"})`` """ tool_req = interrupt.tool_request if tool_req.name != tool_name: raise ValueError(f"Interrupt is for tool '{tool_req.name}', not '{tool_name}'") existing_meta = interrupt.metadata or {} - new_meta = dict(existing_meta) if existing_meta else {} + new_meta: dict[str, Any] = dict(existing_meta) if existing_meta else {} # Mark as resumed new_meta['resumed'] = resumed_metadata if resumed_metadata is not None else True @@ -385,7 +386,7 @@ def define_interrupt( name: str, *, description: str | None = None, - request_metadata: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, + request_metadata: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, # noqa: ANN401 input_schema: type[BaseModel] | dict[str, object] | None = None, output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> Callable[..., Any]: @@ -419,7 +420,8 @@ def get_meta(input: dict) 
-> dict: ) """ - async def interrupt_wrapper(input: Any) -> Any: # noqa: ANN401 + async def interrupt_wrapper(input: Any) -> Any: # noqa: ANN401 - wire JSON args; never returns (raises Interrupt) + # Interrupt tools accept arbitrary JSON args like any tool. meta = None if callable(request_metadata): meta = request_metadata(input) diff --git a/py/packages/genkit/src/genkit/_core/_action.py b/py/packages/genkit/src/genkit/_core/_action.py index 13c361af3f..02d7bf0ad6 100644 --- a/py/packages/genkit/src/genkit/_core/_action.py +++ b/py/packages/genkit/src/genkit/_core/_action.py @@ -32,9 +32,9 @@ from typing_extensions import Never, TypeVar from genkit._core._channel import Channel -from genkit._core._schema import to_json_schema from genkit._core._compat import StrEnum from genkit._core._error import GenkitError +from genkit._core._schema import to_json_schema from genkit._core._trace._path import build_path from genkit._core._tracing import tracer diff --git a/py/packages/genkit/tests/genkit/ai/prompt_test.py b/py/packages/genkit/tests/genkit/ai/prompt_test.py index d095ef0a21..24c31a3dc7 100644 --- a/py/packages/genkit/tests/genkit/ai/prompt_test.py +++ b/py/packages/genkit/tests/genkit/ai/prompt_test.py @@ -19,7 +19,7 @@ import tempfile from pathlib import Path -from typing import Any +from typing import Any, cast from unittest.mock import ANY, MagicMock, patch import pytest @@ -624,7 +624,7 @@ async def test_executable_prompt_opts_removed() -> None: my_prompt = ai.define_prompt(prompt='hi', output_format='text') with pytest.raises(TypeError, match='opts'): - await my_prompt(opts={'model': 'echoModel'}) + await my_prompt(**cast(Any, {'opts': {'model': 'echoModel'}})) @pytest.mark.asyncio diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index fb7f253a72..21363b76b6 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ 
-338,8 +338,10 @@ async def test_generate_with_prompt_as_tool(setup_test: SetupFixture) -> None: assert response.text is not None assert echo.last_request is not None - assert len(echo.last_request.tools) == 1 - assert echo.last_request.tools[0].name == 'subPrompt' + tools = echo.last_request.tools + assert tools is not None + assert len(tools) == 1 + assert tools[0].name == 'subPrompt' @pytest.mark.asyncio diff --git a/py/samples/tool-interrupts/src/approval_example.py b/py/samples/tool-interrupts/src/approval_example.py index d1ecdc6a62..7ed6b684d5 100644 --- a/py/samples/tool-interrupts/src/approval_example.py +++ b/py/samples/tool-interrupts/src/approval_example.py @@ -49,9 +49,7 @@ _RULE = '-' * 52 # Canned user line so the demo always triggers a transfer tool call without stdin. -USER_MESSAGE = ( - 'Please wire $250.00 to Jane Doe (account ending in 4521) for April rent.' -) +USER_MESSAGE = 'Please wire $250.00 to Jane Doe (account ending in 4521) for April rent.' def _print_intro() -> None: @@ -111,18 +109,16 @@ class TransferRequest(BaseModel): async def request_transfer(body: TransferRequest, ctx: ToolRunContext) -> dict: """First run: interrupt for approval. 
After approval: return confirmation with metadata.""" if not ctx.is_resumed(): - line = f"Wire ${body.amount_usd} to {body.to_account}" + line = f'Wire ${body.amount_usd} to {body.to_account}' if body.memo: line = f'{line} — {body.memo}' - raise Interrupt( - { - 'summary': line, - 'to_account': body.to_account, - 'amount_usd': body.amount_usd, - 'memo': body.memo, - 'needs_approval': True, - } - ) + raise Interrupt({ + 'summary': line, + 'to_account': body.to_account, + 'amount_usd': body.amount_usd, + 'memo': body.memo, + 'needs_approval': True, + }) return {'status': 'confirmed', 'resumed': ctx.resumed_metadata} From c33cc1680b058b29047c4945bd7b24af444e4eb9 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 27 Mar 2026 09:41:13 -0500 Subject: [PATCH 05/15] refactor input_schema override --- py/packages/genkit/src/genkit/__init__.py | 2 - py/packages/genkit/src/genkit/_ai/_aio.py | 30 +++++------ .../genkit/src/genkit/_ai/_generate.py | 23 ++++---- py/packages/genkit/src/genkit/_ai/_prompt.py | 7 ++- py/packages/genkit/src/genkit/_ai/_tools.py | 53 ++++++++++++------- .../genkit/src/genkit/_core/_action.py | 33 ++++-------- .../genkit/src/genkit/_core/_registry.py | 8 ++- .../plugins/google_genai/models/gemini.py | 2 +- .../test/models/googlegenai_gemini_test.py | 2 +- 9 files changed, 78 insertions(+), 82 deletions(-) diff --git a/py/packages/genkit/src/genkit/__init__.py b/py/packages/genkit/src/genkit/__init__.py index 4546d1c63a..4b756ed117 100644 --- a/py/packages/genkit/src/genkit/__init__.py +++ b/py/packages/genkit/src/genkit/__init__.py @@ -17,7 +17,6 @@ """Genkit — Build AI-powered applications.""" from genkit._ai._aio import ActionKind, ActionRunContext, Genkit -from genkit._ai._generate import ToolReference from genkit._ai._prompt import ( ExecutablePrompt, ModelStreamResponse, @@ -134,7 +133,6 @@ 'ExecutablePrompt', 'PromptGenerateOptions', 'ToolRunContext', - 'ToolReference', 'ModelRequest', 'ModelResponse', 'ModelResponseChunk', diff --git 
a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index f403a4fd9d..831ef4bc84 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -45,7 +45,7 @@ ) from genkit._ai._formats import built_in_formats from genkit._ai._formats._types import FormatDef -from genkit._ai._generate import ToolReference, define_generate_action, generate_action +from genkit._ai._generate import define_generate_action, generate_action from genkit._ai._model import ( Message, ModelConfig, @@ -277,7 +277,6 @@ def define_interrupt( name: str, *, input_schema: type[BaseModel] | dict[str, object] | None = None, - output_schema: type[BaseModel] | dict[str, object] | None = None, description: str | None = None, ) -> Callable[..., Any]: """Register an interrupt tool that always pauses for user input. @@ -285,7 +284,6 @@ def define_interrupt( Args: name: Tool name input_schema: Optional input schema (Pydantic model or JSON schema dict) - output_schema: Optional output schema (Pydantic model or JSON schema dict) description: Tool description Returns: @@ -295,7 +293,6 @@ def define_interrupt( ask_user = ai.define_interrupt( name='ask_user', input_schema=Question, - output_schema=Answer, description='Ask the user a question', ) """ @@ -304,7 +301,6 @@ def define_interrupt( name, description=description, input_schema=input_schema, - output_schema=output_schema, ) def define_evaluator( @@ -430,7 +426,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -458,7 +454,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: 
dict[str, object] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -486,7 +482,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -514,7 +510,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -540,7 +536,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -778,7 +774,7 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -805,7 
+801,7 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -830,7 +826,7 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -886,7 +882,7 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -914,7 +910,7 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -940,7 +936,7 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: 
Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -1193,7 +1189,7 @@ async def generate_operation( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, config: dict[str, object] | ModelConfig | None = None, diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index a20abe5f51..3ad5c76b2f 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -20,7 +20,7 @@ import inspect import re from collections.abc import Awaitable, Callable, Sequence -from typing import Any, TypeAlias, cast +from typing import Any, cast from pydantic import BaseModel @@ -56,18 +56,19 @@ logger = get_logger(__name__) -# Allowed ``tools=`` values for :meth:`~genkit.Genkit.generate`: -# - tool name (``str``) -# - :class:`~genkit.Action` (e.g. from prompt ``as_tool()``) -# - a function decorated with ``@ai.tool()`` (the same object you registered as a tool) -# Use ``Sequence`` instead of ``list``: type checkers treat ``list`` as strict about the -# exact item type, so ``[my_tool_fn]`` can fail to match ``list[str | Action | ...]`` even -# though it works at runtime. ``Sequence`` does not have that problem. 
-ToolReference: TypeAlias = str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]] +def tools_to_action_refs( + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None, +) -> list[str] | None: + """Normalize tool arguments to registry names for :class:`GenerateActionOptions`. -def tools_to_action_refs(tools: Sequence[ToolReference] | None) -> list[str] | None: - """Normalize tool arguments to registry names for :class:`GenerateActionOptions`.""" + Each item may be: a tool name (``str``), an :class:`~genkit.Action` (e.g. from prompt + ``as_tool()``), or an async function registered as a tool (same object as ``@ai.tool()``). + + Use :class:`~collections.abc.Sequence` in call sites instead of ``list``: type checkers + treat ``list`` as invariant in the item type, so ``[my_tool_fn]`` may not match + ``list[str | Action | ...]`` even though it works at runtime. + """ if tools is None: return None refs: list[str] = [] diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 787debcd90..d07517b01a 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -35,7 +35,6 @@ from typing_extensions import Unpack from genkit._ai._generate import ( - ToolReference, generate_action, resolve_tool, to_tool_definition, @@ -124,7 +123,7 @@ class PromptGenerateOptions(TypedDict, total=False): config: dict[str, Any] | ModelConfig | None messages: list[Message] | None docs: list[Document] | None - tools: Sequence[ToolReference] | None + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None resources: list[str] | None tool_choice: ToolChoice | None output: OutputOptions | None @@ -232,7 +231,7 @@ class PromptConfig(BaseModel): max_turns: int | None = None return_tool_requests: bool | None = None metadata: dict[str, Any] | None = None - tools: Sequence[ToolReference] | None = None + tools: Sequence[str | Action[Any, Any, 
Any] | Callable[..., Awaitable[Any]]] | None = None tool_choice: ToolChoice | None = None use: list[ModelMiddleware] | None = None docs: list[Document] | None = None @@ -264,7 +263,7 @@ def __init__( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, Any] | None = None, - tools: Sequence[ToolReference] | None = None, + tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index 3f295f2eba..f72c091e3f 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -98,7 +98,7 @@ def is_resumed(self) -> bool: return self.resumed_metadata is not None -class Interrupt(Exception): # noqa: N818 - public Genkit name (JS/Go); not renamed *Error for style +class Interrupt(Exception): # noqa: N818 - public Genkit name; not renamed *Error for style """Exception for interrupting tool execution with user-facing API. Raise ``Interrupt(data)`` from a tool or from tool middleware (e.g. ``wrap_tool``). @@ -252,27 +252,22 @@ def _get_func_description(func: Callable[..., Any], description: str | None = No return '' -def define_tool( +def _define_tool( registry: Registry, func: Callable[P, T], name: str | None = None, description: str | None = None, *, input_schema: type[BaseModel] | dict[str, object] | None = None, - output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> Callable[P, T]: """Register a function as a tool. - Args: - registry: The registry to register the tool in. - func: The async function to register as a tool. Must be a coroutine function. - name: Optional name for the tool. Defaults to the function name. - description: Optional description. Defaults to the function's docstring. 
- input_schema: Optional override for tool input JSON schema / validation (Pydantic model or dict). - output_schema: Optional override for tool output JSON schema (Pydantic model or dict). + Normally, the input_schema and output_schem are inferred from func. However, + in some cases, like define_interrupt, the app developer doesn't have a way to + express the input schema in the func signature. - Raises: - TypeError: If func is not an async function. + In that case, the app developer can pass in an input_schema to override the inferred schema. + This will ensure that the model requesting the tool will see the correct input shape. """ # All Python functions have __name__, but ty is strict about Callable protocol if not inspect.iscoroutinefunction(func): @@ -313,9 +308,9 @@ async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - arity dispatch; description=tool_description, fn=tool_fn_wrapper, metadata_fn=func, - input_schema=input_schema, - output_schema=output_schema, ) + if input_schema is not None: + action._override_input_schema(input_schema) @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: # noqa: ANN401 @@ -381,6 +376,28 @@ def restart( return cast(Callable[P, T], wrapper) +def define_tool( + registry: Registry, + func: Callable[P, T], + name: str | None = None, + description: str | None = None, +) -> Callable[P, T]: + """Register a function as a tool. + + Tool input/output JSON Schemas are inferred from ``func`` (first parameter and return type). + + Args: + registry: The registry to register the tool in. + func: The async function to register as a tool. Must be a coroutine function. + name: Optional name for the tool. Defaults to the function name. + description: Optional description. Defaults to the function's docstring. + + Raises: + TypeError: If func is not an async function. 
+ """ + return _define_tool(registry, func, name, description) + + def define_interrupt( registry: Registry, name: str, @@ -388,7 +405,6 @@ def define_interrupt( description: str | None = None, request_metadata: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, # noqa: ANN401 input_schema: type[BaseModel] | dict[str, object] | None = None, - output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> Callable[..., Any]: """Register a tool that always interrupts execution. @@ -402,8 +418,8 @@ def define_interrupt( name: Tool name (registry key) description: Tool description shown to the model request_metadata: Static metadata dict or ``(input) -> dict`` for the interrupt - input_schema: Optional input schema override (Pydantic model or JSON schema dict) - output_schema: Optional output schema override (Pydantic model or JSON schema dict) + input_schema: Optional wire input schema (Pydantic model or JSON schema dict). The + interrupt handler is typed as ``Any``; pass this so the model sees a concrete shape. Returns: The registered tool callable (same shape as :func:`define_tool`). 
@@ -429,11 +445,10 @@ async def interrupt_wrapper(input: Any) -> Any: # noqa: ANN401 - wire JSON args meta = request_metadata raise Interrupt(meta) - return define_tool( + return _define_tool( registry, interrupt_wrapper, name=name, description=description, input_schema=input_schema, - output_schema=output_schema, ) diff --git a/py/packages/genkit/src/genkit/_core/_action.py b/py/packages/genkit/src/genkit/_core/_action.py index 02d7bf0ad6..98ac692824 100644 --- a/py/packages/genkit/src/genkit/_core/_action.py +++ b/py/packages/genkit/src/genkit/_core/_action.py @@ -324,9 +324,6 @@ def __init__( description: str | None = None, metadata: dict[str, object] | None = None, span_metadata: dict[str, SpanAttributeValue] | None = None, - *, - input_schema: type[BaseModel] | dict[str, object] | None = None, - output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> None: self._kind: ActionKind = kind self._name: str = name @@ -348,30 +345,22 @@ def __init__( n_action_args = len(action_args) self._fn = _make_tracing_wrapper(name, kind, span_metadata or {}, n_action_args, fn) self._initialize_io_schemas(action_args, arg_types, resolved_annotations, input_spec) - if input_schema is not None or output_schema is not None: - self._apply_tool_schema_overrides(input_schema=input_schema, output_schema=output_schema) - def _apply_tool_schema_overrides( + def _override_input_schema( self, - *, - input_schema: type[BaseModel] | dict[str, object] | None, - output_schema: type[BaseModel] | dict[str, object] | None, + input_schema: type[BaseModel] | dict[str, object], ) -> None: - """Replace I/O JSON schemas (and input validation) when explicitly provided. + """Replace input JSON schema (and input validation) when explicitly provided. - Used for tools whose ``metadata_fn`` uses loose annotations (e.g. ``Any``) but - callers supply concrete Pydantic models or JSON Schema dicts. 
+ Used when ``metadata_fn`` is loosely typed but the wire contract should be a + Pydantic model or JSON Schema dict. """ - if input_schema is not None: - in_js = to_json_schema(input_schema) - self.input_schema = in_js - if isinstance(input_schema, dict): - self._input_type = None - else: - self._input_type = cast(TypeAdapter[InputT], TypeAdapter(input_schema)) - - if output_schema is not None: - self.output_schema = to_json_schema(output_schema) + in_js = to_json_schema(input_schema) + self.input_schema = in_js + if isinstance(input_schema, dict): + self._input_type = None + else: + self._input_type = cast(TypeAdapter[InputT], TypeAdapter(input_schema)) @property def kind(self) -> ActionKind: diff --git a/py/packages/genkit/src/genkit/_core/_registry.py b/py/packages/genkit/src/genkit/_core/_registry.py index 1e8a9c2330..9df7c99778 100644 --- a/py/packages/genkit/src/genkit/_core/_registry.py +++ b/py/packages/genkit/src/genkit/_core/_registry.py @@ -41,6 +41,7 @@ ModelResponseChunk, ) from genkit._core._plugin import Plugin +from genkit._core._schema import to_json_schema from genkit._core._typing import ( EmbedRequest, EmbedResponse, @@ -138,7 +139,6 @@ def register_action( metadata: dict[str, object] | None = None, span_metadata: dict[str, SpanAttributeValue] | None = None, *, - input_schema: type[BaseModel] | dict[str, object] | None = None, output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> Action[InputT, OutputT, ChunkT]: """Register a new action with the registry. @@ -155,8 +155,6 @@ def register_action( description: Optional human-readable description of the action. metadata: Optional dictionary of metadata about the action. span_metadata: Optional dictionary of tracing span metadata. - input_schema: Optional explicit input JSON schema (Pydantic model or dict). - When set, overrides schemas inferred from ``metadata_fn``. output_schema: Optional explicit output JSON schema (Pydantic model or dict). 
Returns: @@ -170,9 +168,9 @@ def register_action( description=description, metadata=metadata, span_metadata=span_metadata, - input_schema=input_schema, - output_schema=output_schema, ) + if output_schema is not None: + action.output_schema = to_json_schema(output_schema) action_typed = cast(Action[InputT, OutputT, ChunkT], action) with self._lock: if kind not in self._entries: diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 2a6c6a7b44..a7db7f62ea 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -1082,7 +1082,7 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: # Actually Google GenAI expects type=OBJECT for params usually. if not params: params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={}) - # ACC-560: Gemini rejects scalar/array root schemas. Tool params must be OBJECT with + # Gemini rejects scalar/array root schemas. Tool params must be OBJECT with # properties. LLMs always send {"key": value} — wrap scalar/array in {"value": }. elif self._is_scalar_or_array_root(params): params = genai_types.Schema( diff --git a/py/plugins/google-genai/test/models/googlegenai_gemini_test.py b/py/plugins/google-genai/test/models/googlegenai_gemini_test.py index 5bf5602e59..d0dfa1f6c8 100644 --- a/py/plugins/google-genai/test/models/googlegenai_gemini_test.py +++ b/py/plugins/google-genai/test/models/googlegenai_gemini_test.py @@ -484,7 +484,7 @@ def test_gemini_model__create_tool( def test_gemini_model__create_tool_wraps_scalar_input_schema( gemini_model_instance: GeminiModel, ) -> None: - """ACC-560: Scalar/array root input schemas are wrapped in object for Gemini. + """Scalar/array root input schemas are wrapped in object for Gemini. Gemini rejects tool params with type=STRING etc. 
LLMs always send {"key": value} — we wrap scalar schemas in {"value": }. From 15ddedd977b7b2883601d35f2437c3ab88460235 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 27 Mar 2026 10:55:25 -0500 Subject: [PATCH 06/15] parallelize tool execution --- .../genkit/src/genkit/_ai/_generate.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 3ad5c76b2f..564cbef3ef 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -16,6 +16,7 @@ """Generate action.""" +import asyncio import copy import inspect import re @@ -632,32 +633,37 @@ async def resolve_tool_requests( revised_model_message = message.model_copy(deep=True) - has_interrupts = False - response_parts: list[Part] = [] + work: list[tuple[int, Action, ToolRequestPart]] = [] i = 0 for tool_request_part in message.content: if not (isinstance(tool_request_part, Part) and isinstance(tool_request_part.root, ToolRequestPart)): # pyright: ignore[reportUnnecessaryIsInstance] i += 1 continue - # Type is now narrowed: tool_request_part.root is ToolRequestPart tool_req_root = tool_request_part.root tool_request = tool_req_root.tool_request if tool_request.name not in tool_dict: raise RuntimeError(f'failed {tool_request.name} not found') tool = tool_dict[tool_request.name] - tool_response_part, interrupt_part = await _resolve_tool_request(tool, tool_req_root) + work.append((i, tool, tool_req_root)) + i += 1 + + if not work: + return (None, Message(role=Role.TOOL, content=[]), None) + + outs = await asyncio.gather(*[_resolve_tool_request(tool, trp) for _, tool, trp in work]) + has_interrupts = False + response_parts: list[Part] = [] + for (idx, _tool, tool_req_root), (tool_response_part, interrupt_part) in zip(work, outs, strict=True): if tool_response_part: - revised_model_message.content[i] = _to_pending_response(tool_req_root, 
tool_response_part) + revised_model_message.content[idx] = _to_pending_response(tool_req_root, tool_response_part) response_parts.append(Part(root=tool_response_part)) if interrupt_part: has_interrupts = True - revised_model_message.content[i] = Part(root=interrupt_part) - - i += 1 + revised_model_message.content[idx] = Part(root=interrupt_part) if has_interrupts: return (revised_model_message, None, None) From cccac65a803c2ceb490d1a20395b6332904d3174 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 27 Mar 2026 16:50:04 -0500 Subject: [PATCH 07/15] remove prompts as tools, added tests --- .../genkit/src/genkit/_ai/_generate.py | 18 +- py/packages/genkit/src/genkit/_ai/_prompt.py | 31 - py/packages/genkit/src/genkit/_ai/_tools.py | 29 +- .../genkit/tests/genkit/ai/_tools_test.py | 191 +++++ .../tests/genkit/ai/generate_helpers_test.py | 92 +++ .../ai/generate_interrupt_resume_test.py | 684 ++++++++++++++++++ .../genkit/tests/genkit/ai/generate_test.py | 76 ++ .../genkit/tests/genkit/veneer/veneer_test.py | 28 - 8 files changed, 1062 insertions(+), 87 deletions(-) create mode 100644 py/packages/genkit/tests/genkit/ai/_tools_test.py create mode 100644 py/packages/genkit/tests/genkit/ai/generate_helpers_test.py create mode 100644 py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 564cbef3ef..acdcdc3f4c 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -63,8 +63,8 @@ def tools_to_action_refs( ) -> list[str] | None: """Normalize tool arguments to registry names for :class:`GenerateActionOptions`. - Each item may be: a tool name (``str``), an :class:`~genkit.Action` (e.g. from prompt - ``as_tool()``), or an async function registered as a tool (same object as ``@ai.tool()``). 
+ Each item may be: a tool name (``str``), a :class:`~genkit.Action` of kind ``TOOL``, + or an async function registered as a tool (same object as ``@ai.tool()``). Use :class:`~collections.abc.Sequence` in call sites instead of ``list``: type checkers treat ``list`` as invariant in the item type, so ``[my_tool_fn]`` may not match @@ -734,16 +734,10 @@ async def _resolve_tool_request( async def resolve_tool(registry: Registry, tool_name: str) -> Action: - """Resolve a tool by name from the registry. - - Tries :class:`~genkit.ActionKind.TOOL` first, then prompt actions registered as - :class:`~genkit.ActionKind.PROMPT` / :class:`~genkit.ActionKind.EXECUTABLE_PROMPT` - (e.g. :meth:`~genkit.ExecutablePrompt.as_tool`). - """ - for kind in (ActionKind.TOOL, ActionKind.PROMPT, ActionKind.EXECUTABLE_PROMPT): - tool = await registry.resolve_action(kind=kind, name=tool_name) - if tool is not None: - return tool + """Resolve a :class:`~genkit.ActionKind.TOOL` action by name from the registry.""" + tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name) + if tool is not None: + return tool msg = f'Unable to resolve tool {tool_name}' raise ValueError(msg) diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index d07517b01a..4115476353 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -547,37 +547,6 @@ async def render( coerced = _coerce_prompt_opts(opts) return await self._render_impl(input, coerced) - async def as_tool(self) -> Action: - """Expose this prompt as a tool. - - Returns the PROMPT action, which can be used as a tool. 
- """ - await self._ensure_resolved() - # If we have a direct reference to the action, use it - if self._prompt_action is not None: - return self._prompt_action - - # Otherwise, try to look it up using name/variant/ns - if self._name is None: - raise GenkitError( - status='FAILED_PRECONDITION', - message=( - 'Prompt name not available. This prompt was not created via define_prompt_async() or load_prompt().' - ), - ) - - lookup_key = registry_lookup_key(self._name, self._variant, self._ns) - - action = await self._registry.resolve_action_by_key(lookup_key) - - if action is None or action.kind != ActionKind.PROMPT: - raise GenkitError( - status='NOT_FOUND', - message=f'PROMPT action not found for prompt "{self._name}"', - ) - - return action - def register_prompt_actions( registry: Registry, diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index f72c091e3f..d490e0f3a6 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -183,16 +183,15 @@ def restart_interrupted_tool( restart = getattr(tool, 'restart', None) if restart is None: raise TypeError(f'{tool!r} has no restart method; pass a tool from define_tool or Genkit.tool') - part = restart( + out = restart( replace_input, interrupt=interrupt, resumed_metadata=resumed_metadata, ) - root = part.root - if not isinstance(root, ToolRequestPart): - msg = f'Expected tool.restart() to return Part(root=ToolRequestPart), got {type(root)!r}' + if not isinstance(out, ToolRequestPart): + msg = f'Expected tool.restart() to return ToolRequestPart, got {type(out)!r}' raise TypeError(msg) - return root + return out async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolRequestPart) -> Part: @@ -324,7 +323,7 @@ def restart( *, interrupt: ToolRequestPart, resumed_metadata: dict[str, Any] | None = None, - ) -> Part: + ) -> ToolRequestPart: """Create a restart request for an interrupted tool call. 
Args: @@ -334,7 +333,7 @@ def restart( resumed_metadata: Passed to the tool as :attr:`ToolRunContext.resumed_metadata`. Returns: - ``Part`` wrapping a ``ToolRequestPart`` for ``resume_restart`` / message history. + A ``ToolRequestPart`` for ``resume_restart`` / message history. Example: ``pay_invoice.restart({**trp.tool_request.input, "confirmed": True}, interrupt=trp,`` @@ -360,15 +359,13 @@ def restart( if 'interrupt' in new_meta: del new_meta['interrupt'] - return Part( - root=ToolRequestPart( - tool_request=ToolRequest( - name=tool_req.name, - ref=tool_req.ref, - input=new_input, - ), - metadata=new_meta, - ) + return ToolRequestPart( + tool_request=ToolRequest( + name=tool_req.name, + ref=tool_req.ref, + input=new_input, + ), + metadata=new_meta, ) wrapper.restart = restart # type: ignore[attr-defined] diff --git a/py/packages/genkit/tests/genkit/ai/_tools_test.py b/py/packages/genkit/tests/genkit/ai/_tools_test.py new file mode 100644 index 0000000000..818874e9df --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/_tools_test.py @@ -0,0 +1,191 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for tool restart builder and run_tool_after_restart.""" + +import pytest + +from genkit import ActionKind, Genkit +from genkit._ai._generate import run_tool_after_restart +from genkit._ai._tools import ( + Interrupt, + ToolRunContext, + _tool_original_input, + _tool_resumed_metadata, +) +from genkit._core._error import GenkitError +from genkit._core._typing import ToolRequest, ToolRequestPart + + +@pytest.mark.asyncio +async def test_restart_sets_resumed_metadata_and_strips_interrupt() -> None: + """``tool.restart`` → TRP.metadata.resumed plus ``interrupt`` removed; input unchanged when replace is None.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='r1', input={'amount': 10}), + 
metadata={'interrupt': {'reason': 'hold'}}, + ) + out = pay.restart(None, interrupt=interrupt_trp, resumed_metadata={'k': 'v'}) + assert isinstance(out, ToolRequestPart) + assert out.metadata is not None + assert out.metadata.get('resumed') == {'k': 'v'} + assert 'interrupt' not in out.metadata + assert out.tool_request.input == {'amount': 10} + + +@pytest.mark.asyncio +async def test_restart_replace_input_sets_replaced_input() -> None: + """Restart with new input sets ``replacedInput`` to prior input and updates ``tool_request.input``.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='r1', input={'amount': 10}), + metadata={'interrupt': True}, + ) + out = pay.restart({'amount': 99}, interrupt=interrupt_trp, resumed_metadata={'by': 'u'}) + assert isinstance(out, ToolRequestPart) + assert out.metadata is not None + assert out.metadata.get('replacedInput') == {'amount': 10} + assert out.tool_request.input == {'amount': 99} + assert out.metadata.get('resumed') == {'by': 'u'} + + +@pytest.mark.asyncio +async def test_restart_resumed_defaults_to_true() -> None: + """When ``resumed_metadata=None``, restart TRP sets ``metadata.resumed`` to True.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='r1', input={}), + metadata={'interrupt': True}, + ) + out = pay.restart(None, interrupt=interrupt_trp, resumed_metadata=None) + assert isinstance(out, ToolRequestPart) + assert out.metadata is not None + assert out.metadata.get('resumed') is True + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_resumed_true_maps_to_empty_dict_in_context() -> None: + """``run_tool_after_restart``: ``metadata.resumed is True`` → ``ToolRunContext.resumed_metadata`` is ``{}``.""" + ai = Genkit() + captured: 
list[tuple[dict | None, object | None]] = [] + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + captured.append((ctx.resumed_metadata, ctx.original_input)) + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={'q': 1}), + metadata={'resumed': True}, + ) + await run_tool_after_restart(action, restart_trp) + assert len(captured) == 1 + assert captured[0][0] == {} + assert captured[0][1] is None + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_resumed_dict() -> None: + """Restart TRP with ``metadata.resumed`` dict is passed through to ``ToolRunContext.resumed_metadata``.""" + ai = Genkit() + captured: list[dict | None] = [] + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + captured.append(ctx.resumed_metadata) + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={}), + metadata={'resumed': {'by': 'x'}}, + ) + await run_tool_after_restart(action, restart_trp) + assert captured == [{'by': 'x'}] + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_replaced_input() -> None: + """``replacedInput`` on TRP sets tool input from current request and ``original_input`` from prior.""" + ai = Genkit() + captured: list[tuple[object, object | None]] = [] + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + captured.append((inp, ctx.original_input)) + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={'new': True}), + metadata={'resumed': True, 
'replacedInput': {'old': True}}, + ) + await run_tool_after_restart(action, restart_trp) + assert len(captured) == 1 + assert captured[0][0] == {'new': True} + assert captured[0][1] == {'old': True} + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_resets_contextvars() -> None: + """After ``run_tool_after_restart`` returns, resume ContextVars are cleared (no leak between runs).""" + ai = Genkit() + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={}), + metadata={'resumed': True}, + ) + await run_tool_after_restart(action, restart_trp) + assert _tool_resumed_metadata.get() is None + assert _tool_original_input.get() is None + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_nested_interrupt_raises() -> None: + """Tool raising ``Interrupt`` during a restart run raises ``GenkitError`` (nested interrupt unsupported).""" + ai = Genkit() + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + raise Interrupt() + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={}), + metadata={'resumed': True}, + ) + with pytest.raises(GenkitError) as ei: + await run_tool_after_restart(action, restart_trp) + msg = ei.value.original_message.lower() + assert 'restart' in msg or 'nested' in msg or 'not supported' in msg diff --git a/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py new file mode 100644 index 0000000000..54f9cf4203 --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py @@ -0,0 +1,92 @@ +# Copyright 2025 Google LLC +# 
SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for private helpers in genkit._ai._generate (interrupt / resume).""" + +import pytest + +from genkit._ai._generate import ( + _find_corresponding_restart, + _find_corresponding_tool_response, + _interrupt_from_tool_exc, + _to_pending_response, +) +from genkit._ai._tools import Interrupt +from genkit._core._error import GenkitError +from genkit._core._typing import Part, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart + + +def test_find_corresponding_restart_matches_name_and_ref() -> None: + """``_find_corresponding_restart`` picks the resume TRP whose name+ref match the pending TRP; else None.""" + pending = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={}), + ) + match = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={'new': True}), + metadata={'resumed': True}, + ) + other_ref = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r2', input={}), + metadata={'resumed': True}, + ) + other_name = ToolRequestPart( + tool_request=ToolRequest(name='u', ref='r1', input={}), + metadata={'resumed': True}, + ) + + assert _find_corresponding_restart([match], pending) is match + assert _find_corresponding_restart([other_ref, match], pending) is match + assert _find_corresponding_restart([other_ref], pending) is None + assert _find_corresponding_restart([other_name], pending) is None + assert _find_corresponding_restart(None, pending) is None + assert _find_corresponding_restart([], pending) is None + + +def test_find_corresponding_tool_response_matches_name_and_ref() -> None: + """``_find_corresponding_tool_response`` matches ``ToolResponsePart`` to pending TRP by name+ref.""" + pending = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={}), + ) + trp = ToolResponsePart( + tool_response=ToolResponse(name='t', ref='r1', output=42), + ) + other = ToolResponsePart( + tool_response=ToolResponse(name='t', ref='r2', output=0), + ) + + got = 
_find_corresponding_tool_response([trp], pending) + assert got is not None + assert got.root == trp + + assert _find_corresponding_tool_response([other], pending) is None + assert _find_corresponding_tool_response([], pending) is None + + +def test_interrupt_from_tool_exc() -> None: + """``_interrupt_from_tool_exc`` unwraps bare ``Interrupt`` or ``GenkitError.cause``; else None.""" + intr = Interrupt({'x': 1}) + assert _interrupt_from_tool_exc(intr) is intr + + wrapped = GenkitError(message='x', cause=intr) + assert _interrupt_from_tool_exc(wrapped) is intr + + assert _interrupt_from_tool_exc(ValueError('x')) is None + + +def test_to_pending_response_sets_pending_output() -> None: + """``_to_pending_response`` merges prior TRP metadata with ``pendingOutput`` from the tool response.""" + req = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={'a': 1}), + metadata={'interrupt': {'old': True}}, + ) + resp = ToolResponsePart( + tool_response=ToolResponse(name='t', ref='r1', output={'out': 2}), + ) + part = _to_pending_response(req, resp) + root = part.root + assert isinstance(root, ToolRequestPart) + assert root.tool_request.name == 't' + assert root.tool_request.ref == 'r1' + assert root.metadata is not None + assert root.metadata.get('pendingOutput') == {'out': 2} + assert root.metadata.get('interrupt') == {'old': True} diff --git a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py new file mode 100644 index 0000000000..5fc3f171ab --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py @@ -0,0 +1,684 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for interrupt, resume, and restart behavior in ``generate_action``.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from genkit import Genkit, Message, ModelResponse +from genkit._ai._generate 
import _resolve_resumed_tool_request, generate_action +from genkit._ai._tools import Interrupt, ToolRunContext, respond_to_interrupt +from genkit._ai._testing import define_programmable_model +from genkit._core._error import GenkitError +from genkit._core._model import GenerateActionOptions +from genkit._core._typing import FinishReason, Part, Resume + + +def _wire(messages: list[Message]) -> list[dict[str, Any]]: + """Messages as JSON-shaped dicts (``model_dump`` with aliases) for comparing to expected wire.""" + return [m.model_dump(mode='json', exclude_none=True, by_alias=True) for m in messages] + + +def _gen_opts(ai: Genkit, *, tools: list[str], messages: list[Message], resume: Resume | None = None) -> GenerateActionOptions: + return GenerateActionOptions( + model='programmableModel', + messages=messages, + tools=tools, + resume=resume, + ) + + +@pytest.mark.asyncio +async def test_normal_two_arg_tools_see_no_resume_context() -> None: + """Two tools in one batch, no interrupt: ``ToolRunContext`` should stay empty of resume fields. + + The model asks for both tools in one turn; each tool records whether it thinks it's a resume + (it shouldn't), and we compare the whole conversation to the expected wire. 
+ """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + seen: list[tuple[bool, object | None, object | None]] = [] + + @ai.tool(name='u1') + async def u1(_: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + seen.append((ctx.is_resumed(), ctx.resumed_metadata, ctx.original_input)) + return 'a' + + @ai.tool(name='u2') + async def u2(_: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + seen.append((ctx.is_resumed(), ctx.resumed_metadata, ctx.original_input)) + return 'b' + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate( + { + 'role': 'model', + 'content': [ + {'text': 'go'}, + {'toolRequest': {'ref': '1', 'name': 'u1', 'input': {}}}, + {'toolRequest': {'ref': '2', 'name': 'u2', 'input': {}}}, + ], + } + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'done'}]}), + ) + ) + + r = await generate_action( + ai.registry, + _gen_opts(ai, tools=['u1', 'u2'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert seen == [(False, None, None), (False, None, None)] + assert _wire(r.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'go'}, + {'toolRequest': {'ref': '1', 'name': 'u1', 'input': {}}}, + {'toolRequest': {'ref': '2', 'name': 'u2', 'input': {}}}, + ], + }, + { + 'role': 'tool', + 'content': [ + {'toolResponse': {'ref': '1', 'name': 'u1', 'output': 'a'}}, + {'toolResponse': {'ref': '2', 'name': 'u2', 'output': 'b'}}, + ], + }, + { + 'role': 'model', + 'content': [{'text': 'done'}], + }, + ] + + +@pytest.mark.asyncio +async def test_interrupt_wires_trp_metadata_interrupt_and_stops() -> None: + """When the tool raises ``Interrupt``, the interrupt payload lands on the TRP metadata, finish is + ``INTERRUPTED``, and we never get a ``role=tool`` row yet—only user + model in the history. 
+ """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='intr') + async def intr(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'reason': 'x'}) + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate( + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, + ], + } + ), + ) + ) + + r = await generate_action( + ai.registry, + _gen_opts(ai, tools=['intr'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert r.finish_reason == FinishReason.INTERRUPTED + assert _wire(r.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + { + 'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}, + 'metadata': {'interrupt': {'reason': 'x'}}, + }, + ], + }, + ] + + +@pytest.mark.asyncio +async def test_resume_respond_trp_gets_resolved_interrupt_and_tool_trp() -> None: + """Follow-up generate with ``Resume(respond=[...])``: the stuck TRP picks up ``resolvedInterrupt``, + the tool reply shows up under ``interruptResponse``, and the model can answer again. Compares + wire before and after the interrupt. 
+ """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='intr') + async def intr(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'reason': 'x'}) + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate( + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, + ], + } + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'after resume'}]}), + ) + ) + + first = await generate_action( + ai.registry, + _gen_opts(ai, tools=['intr'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + assert _wire(first.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + { + 'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}, + 'metadata': {'interrupt': {'reason': 'x'}}, + }, + ], + }, + ] + + reply = respond_to_interrupt({'bar': 2}, interrupt=first.interrupts[0]) + + second = await generate_action( + ai.registry, + _gen_opts(ai, tools=['intr'], messages=list(first.messages), resume=Resume(respond=[reply])), + ) + + assert second.finish_reason == FinishReason.STOP + assert _wire(second.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + { + 'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'reason': 'x'}}, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + { + 'toolResponse': {'ref': 'r1', 'name': 'intr', 'output': {'bar': 2}}, + 'metadata': {'interruptResponse': True}, + }, + ], + 'metadata': {'resumed': True}, + }, + { + 'role': 'model', + 'content': [{'text': 'after resume'}], + }, + ] + + +@pytest.mark.asyncio +async def 
test_tool_either_interrupts_or_returns() -> None: + """Same ``bank_transfer`` tool: ``preapproved=False`` interrupts (short history), + ``preapproved=True`` runs through to a normal tool response and a final model turn. Two + back-to-back runs with different queued model responses, each wire asserted in full. + """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='bank_transfer') + async def bank_transfer(inp: dict) -> int: + if not inp.get('preapproved'): + raise Interrupt({'reason': 'awaiting_approval'}) + return 42 + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate( + { + 'role': 'model', + 'content': [ + {'text': 't'}, + { + 'toolRequest': { + 'ref': 'g', + 'name': 'bank_transfer', + 'input': {'preapproved': False}, + }, + }, + ], + } + ), + ) + ) + r_fail = await generate_action( + ai.registry, + _gen_opts( + ai, + tools=['bank_transfer'], + messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})], + ), + ) + assert r_fail.finish_reason == FinishReason.INTERRUPTED + assert _wire(r_fail.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 't'}, + { + 'toolRequest': { + 'ref': 'g', + 'name': 'bank_transfer', + 'input': {'preapproved': False}, + }, + 'metadata': {'interrupt': {'reason': 'awaiting_approval'}}, + }, + ], + }, + ] + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate( + { + 'role': 'model', + 'content': [ + {'text': 't2'}, + { + 'toolRequest': { + 'ref': 'g2', + 'name': 'bank_transfer', + 'input': {'preapproved': True}, + }, + }, + ], + } + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'ok'}]}), + ) + ) + r_ok = await generate_action( + ai.registry, + _gen_opts( + ai, + tools=['bank_transfer'], + 
messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})], + ), + ) + assert r_ok.finish_reason == FinishReason.STOP + assert _wire(r_ok.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 't2'}, + { + 'toolRequest': { + 'ref': 'g2', + 'name': 'bank_transfer', + 'input': {'preapproved': True}, + }, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + {'toolResponse': {'ref': 'g2', 'name': 'bank_transfer', 'output': 42}}, + ], + }, + { + 'role': 'model', + 'content': [{'text': 'ok'}], + }, + ] + + +@pytest.mark.asyncio +async def test_resume_restart_runs_tool_second_time_and_resolved_interrupt_on_model() -> None: + """``Resume(restart=[...])`` reruns the tool with new input after an interrupt. The tool runs + twice (tracked in ``calls``); the second pass shows ``resolvedInterrupt`` on the model TRP and a + plain ``toolResponse`` (no ``interruptResponse`` on that path). + """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + calls: list[str] = [] + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: + calls.append('run') + if not inp.get('ok'): + raise Interrupt({'hold': True}) + return 'paid' + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate( + { + 'role': 'model', + 'content': [ + {'text': 'x'}, + {'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}}, + ], + } + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'final'}]}), + ) + ) + + first = await generate_action( + ai.registry, + _gen_opts(ai, tools=['pay'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + assert _wire(first.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + 
{'text': 'x'}, + { + 'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}, + 'metadata': {'interrupt': {'hold': True}}, + }, + ], + }, + ] + + restart_trp = pay.restart({'ok': True}, interrupt=first.interrupts[0], resumed_metadata={'by': 'test'}) + + second = await generate_action( + ai.registry, + _gen_opts(ai, tools=['pay'], messages=list(first.messages), resume=Resume(restart=[restart_trp])), + ) + + assert second.finish_reason == FinishReason.STOP + assert calls == ['run', 'run'] + assert _wire(second.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'x'}, + { + 'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'hold': True}}, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + {'toolResponse': {'ref': 'p1', 'name': 'pay', 'output': 'paid'}}, + ], + 'metadata': {'resumed': True}, + }, + { + 'role': 'model', + 'content': [{'text': 'final'}], + }, + ] + + +@pytest.mark.asyncio +async def test_mixed_resume_one_respond_one_restart() -> None: + """Two tool calls both interrupt in one turn; the next generate fills in a ``respond`` for one + ref and a ``restart`` for the other. Expect one tool message with two parts (respond path still + has ``interruptResponse``; restart path does not). 
+ """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='a') + async def a_tool(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'tool': 'a'}) + + @ai.tool(name='b') + async def b_tool(inp: dict) -> str: + if inp.get('ok'): + return 'b-done' + raise Interrupt({'tool': 'b'}) + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate( + { + 'role': 'model', + 'content': [ + {'text': 'both'}, + {'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}}, + {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, + ], + } + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'end'}]}), + ) + ) + + first = await generate_action( + ai.registry, + _gen_opts(ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + assert _wire(first.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'both'}, + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'interrupt': {'tool': 'a'}}, + }, + { + 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, + 'metadata': {'interrupt': {'tool': 'b'}}, + }, + ], + }, + ] + + ia = next(p for p in first.interrupts if p.tool_request.name == 'a') + ib = next(p for p in first.interrupts if p.tool_request.name == 'b') + + second = await generate_action( + ai.registry, + _gen_opts( + ai, + tools=['a', 'b'], + messages=list(first.messages), + resume=Resume( + respond=[respond_to_interrupt({'done': True}, interrupt=ia)], + restart=[b_tool.restart({'ok': True}, interrupt=ib, resumed_metadata=None)], + ), + ), + ) + + assert second.finish_reason == FinishReason.STOP + assert _wire(second.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 
'model', + 'content': [ + {'text': 'both'}, + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'tool': 'a'}}, + }, + { + 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'tool': 'b'}}, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + { + 'toolResponse': {'ref': 'ra', 'name': 'a', 'output': {'done': True}}, + 'metadata': {'interruptResponse': True}, + }, + {'toolResponse': {'ref': 'rb', 'name': 'b', 'output': 'b-done'}}, + ], + 'metadata': {'resumed': True}, + }, + { + 'role': 'model', + 'content': [{'text': 'end'}], + }, + ] + + +@pytest.mark.asyncio +async def test_pending_output_trp_yields_tool_response_with_source_pending() -> None: + """If the TRP only has ``pendingOutput``, ``_resolve_resumed_tool_request`` should synthesize a + tool response with that output and ``metadata.source`` set to ``pending``—checked on the part + directly, not a full message list. + """ + ai = Genkit() + part = Part.model_validate( + { + 'toolRequest': {'name': 't', 'ref': 'r', 'input': {}}, + 'metadata': {'pendingOutput': 123}, + } + ) + _, resp = await _resolve_resumed_tool_request( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'x'}]})], + ), + part, + ) + assert resp.model_dump(mode='json', exclude_none=True, by_alias=True) == { + 'toolResponse': {'ref': 'r', 'name': 't', 'output': 123}, + 'metadata': {'source': 'pending'}, + } + + +@pytest.mark.asyncio +async def test_resume_without_matching_replies_raises() -> None: + """Hand-built history with an interrupted TRP but an empty ``Resume()``: expect ``GenkitError`` + and a message that mentions replies or restarts. 
+ """ + ai = Genkit() + _, _ = define_programmable_model(ai) + + with pytest.raises(GenkitError) as ei: + await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[ + Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]}), + Message.model_validate( + { + 'role': 'model', + 'content': [ + { + 'toolRequest': {'ref': 'z', 'name': 'missing', 'input': {}}, + 'metadata': {'interrupt': True}, + }, + ], + } + ), + ], + resume=Resume(), + ), + ) + assert 'replies' in ei.value.original_message or 'restarts' in ei.value.original_message + + +@pytest.mark.asyncio +async def test_resume_requires_last_message_model_with_tool_requests() -> None: + """Can't resume when the transcript ends on a user turn: ``GenkitError``, and the message should + mention needing a model message. + """ + ai = Genkit() + _, _ = define_programmable_model(ai) + + with pytest.raises(GenkitError) as ei: + await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'only user'}]})], + resume=Resume(), + ), + ) + assert 'model' in ei.value.original_message.lower() + diff --git a/py/packages/genkit/tests/genkit/ai/generate_test.py b/py/packages/genkit/tests/genkit/ai/generate_test.py index 0a550c9093..f24da35c0e 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_test.py @@ -16,6 +16,7 @@ from genkit import ActionKind, Document, Genkit, Message, ModelResponse, ModelResponseChunk from genkit._ai._generate import generate_action +from genkit._ai._tools import Interrupt from genkit._ai._model import text_from_content, text_from_message from genkit._ai._testing import ( ProgrammableModel, @@ -30,6 +31,8 @@ Part, Role, TextPart, + ToolRequest, + ToolRequestPart, ) @@ -370,6 +373,79 @@ def collect_chunks(c: ModelResponseChunk) -> None: ] +@pytest.mark.asyncio +async def 
test_parallel_tool_requests_one_interrupt_keeps_pending_output_for_others( + setup_test: tuple[Genkit, ProgrammableModel], +) -> None: + """With asyncio.gather in resolve_tool_requests: one interrupt still records pendingOutput for others.""" + ai, pm = setup_test + + @ai.tool(name='tool_a') + async def tool_a() -> str: + return 'a_ok' + + @ai.tool(name='tool_b') + async def tool_b() -> None: + raise Interrupt({'stop': True}) + + @ai.tool(name='tool_c') + async def tool_c() -> str: + return 'c_ok' + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message( + role=Role.MODEL, + content=[ + Part(TextPart(text='call three')), + Part( + root=ToolRequestPart( + tool_request=ToolRequest(name='tool_a', ref='ref-a', input={}), + ) + ), + Part( + root=ToolRequestPart( + tool_request=ToolRequest(name='tool_b', ref='ref-b', input={}), + ) + ), + Part( + root=ToolRequestPart( + tool_request=ToolRequest(name='tool_c', ref='ref-c', input={}), + ) + ), + ], + ), + ) + ) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[ + Message(role=Role.USER, content=[Part(TextPart(text='hi'))]), + ], + tools=['tool_a', 'tool_b', 'tool_c'], + ), + ) + + assert response.finish_reason == FinishReason.INTERRUPTED + assert response.message is not None + parts = response.message.content + assert len(parts) == 4 + assert parts[0].root == TextPart(text='call three') + a_root = parts[1].root + b_root = parts[2].root + c_root = parts[3].root + assert isinstance(a_root, ToolRequestPart) + assert isinstance(b_root, ToolRequestPart) + assert isinstance(c_root, ToolRequestPart) + assert a_root.metadata and a_root.metadata.get('pendingOutput') == 'a_ok' + assert b_root.metadata and b_root.metadata.get('interrupt') == {'stop': True} + assert c_root.metadata and c_root.metadata.get('pendingOutput') == 'c_ok' + + ########################################################################## # run tests from 
/tests/specs/generate.yaml ########################################################################## diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index 21363b76b6..307fb36c0d 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -316,34 +316,6 @@ async def test_generate_with_system_prompt_messages( assert (await stream_result.response).text == want_txt -@pytest.mark.asyncio -async def test_generate_with_prompt_as_tool(setup_test: SetupFixture) -> None: - """Test that ai.generate(tools=[prompt.as_tool()]) works (ACC-557). - - Previously, tools param only accepted list[str]; prompt.as_tool() returns Action, - causing Pydantic ValidationError. This test verifies Action is now accepted. - """ - ai, echo, *_ = setup_test - - sub_prompt = ai.define_prompt(name='subPrompt', prompt='Sub prompt') - prompt_action = await sub_prompt.as_tool() - - # Should NOT raise ValidationError - Action in tools is now accepted - response = await ai.generate( - model='echoModel', - prompt='Use the sub prompt', - tools=[prompt_action], - tool_choice=ToolChoice.REQUIRED, - ) - - assert response.text is not None - assert echo.last_request is not None - tools = echo.last_request.tools - assert tools is not None - assert len(tools) == 1 - assert tools[0].name == 'subPrompt' - - @pytest.mark.asyncio async def test_generate_with_tools(setup_test: SetupFixture) -> None: """Test that the generate function with tools works.""" From 81e2f08553f3feefb0b4a1926912a58b0c195577 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 27 Mar 2026 21:01:02 -0500 Subject: [PATCH 08/15] clean up Tool interface --- py/packages/genkit/src/genkit/__init__.py | 4 +- py/packages/genkit/src/genkit/_ai/_aio.py | 38 ++-- .../genkit/src/genkit/_ai/_generate.py | 22 +- py/packages/genkit/src/genkit/_ai/_prompt.py | 9 +- py/packages/genkit/src/genkit/_ai/_tools.py | 200 
++++++++---------- py/packages/genkit/src/genkit/_core/_flow.py | 4 +- .../tests/genkit/ai/generate_helpers_test.py | 4 +- .../ai/generate_interrupt_resume_test.py | 181 ++++++++-------- .../genkit/tests/genkit/ai/generate_test.py | 2 +- .../tool-interrupts/src/approval_example.py | 8 +- .../tool-interrupts/src/respond_example.py | 2 +- 11 files changed, 206 insertions(+), 268 deletions(-) diff --git a/py/packages/genkit/src/genkit/__init__.py b/py/packages/genkit/src/genkit/__init__.py index 4b756ed117..0a4c5ab28d 100644 --- a/py/packages/genkit/src/genkit/__init__.py +++ b/py/packages/genkit/src/genkit/__init__.py @@ -24,9 +24,9 @@ ) from genkit._ai._tools import ( Interrupt, + Tool, ToolRunContext, respond_to_interrupt, - restart_interrupted_tool, ) from genkit._core._action import Action, StreamResponse from genkit._core._error import GenkitError, PublicError @@ -98,8 +98,8 @@ 'GenkitError', 'PublicError', 'Interrupt', + 'Tool', 'respond_to_interrupt', - 'restart_interrupted_tool', # Content types 'Constrained', 'CustomPart', diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index 831ef4bc84..1fc8844306 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -27,7 +27,7 @@ import uuid from collections.abc import Awaitable, Callable, Coroutine, Sequence from pathlib import Path -from typing import Any, ParamSpec, TypeVar, cast, overload +from typing import Any, TypeVar, cast, overload import anyio import uvicorn @@ -71,7 +71,7 @@ ResourceOptions, define_resource, ) -from genkit._ai._tools import define_interrupt, define_tool +from genkit._ai._tools import Tool, define_interrupt, define_tool from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._background import ( BackgroundAction, @@ -120,7 +120,7 @@ InputT = TypeVar('InputT') OutputT = TypeVar('OutputT') ChunkT = TypeVar('ChunkT') -P = ParamSpec('P') + R = TypeVar('R') T = 
TypeVar('T') @@ -262,12 +262,10 @@ def define_dynamic_action_provider( metadata=metadata, ) - def tool( - self, name: str | None = None, description: str | None = None - ) -> Callable[[Callable[P, T]], Callable[P, T]]: + def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable[..., Any]], Tool]: """Decorator to register a function as a tool.""" - def wrapper(func: Callable[P, T]) -> Callable[P, T]: + def wrapper(func: Callable[..., Any]) -> Tool: return define_tool(self.registry, func, name, description) return wrapper @@ -278,7 +276,7 @@ def define_interrupt( *, input_schema: type[BaseModel] | dict[str, object] | None = None, description: str | None = None, - ) -> Callable[..., Any]: + ) -> Tool: """Register an interrupt tool that always pauses for user input. Args: @@ -426,7 +424,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -454,7 +452,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -482,7 +480,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | 
None = None, docs: list[Document] | None = None, @@ -510,7 +508,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -536,7 +534,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -774,7 +772,7 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -801,7 +799,7 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -826,7 +824,7 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., 
Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -882,7 +880,7 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -910,7 +908,7 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -936,7 +934,7 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, @@ -1189,7 +1187,7 @@ async def generate_operation( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, config: dict[str, 
object] | ModelConfig | None = None, diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index acdcdc3f4c..4e10161f49 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -20,7 +20,7 @@ import copy import inspect import re -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Callable, Sequence from typing import Any, cast from pydantic import BaseModel @@ -36,7 +36,7 @@ ModelResponseChunk, ) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources -from genkit._ai._tools import Interrupt, run_tool_after_restart, unwrap_wrapped_scalar_tool_input_if_needed +from genkit._ai._tools import Interrupt, Tool, run_tool_after_restart, unwrap_wrapped_scalar_tool_input_if_needed from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import GenkitError from genkit._core._logger import get_logger @@ -59,16 +59,12 @@ def tools_to_action_refs( - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None, + tools: Sequence[str | Tool] | None, ) -> list[str] | None: """Normalize tool arguments to registry names for :class:`GenerateActionOptions`. - Each item may be: a tool name (``str``), a :class:`~genkit.Action` of kind ``TOOL``, - or an async function registered as a tool (same object as ``@ai.tool()``). - - Use :class:`~collections.abc.Sequence` in call sites instead of ``list``: type checkers - treat ``list`` as invariant in the item type, so ``[my_tool_fn]`` may not match - ``list[str | Action | ...]`` even though it works at runtime. + Each item may be a tool name (``str``) or a :class:`~genkit.Tool` returned by + :meth:`~genkit.Genkit.tool`. 
""" if tools is None: return None @@ -76,14 +72,8 @@ def tools_to_action_refs( for t in tools: if isinstance(t, str): refs.append(t) - elif isinstance(t, Action): - refs.append(t.name) else: - name = getattr(t, '__name__', None) - if not isinstance(name, str) or not name: - msg = f'Cannot resolve tool name from callable: {t!r}' - raise TypeError(msg) - refs.append(name) + refs.append(t.name) return refs diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 4115476353..4d453831b0 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -47,6 +47,7 @@ ModelResponse, ModelResponseChunk, ) +from genkit._ai._tools import Tool from genkit._core._action import Action, ActionKind, ActionRunContext, StreamingCallback, create_action_key from genkit._core._channel import Channel from genkit._core._error import GenkitError @@ -123,7 +124,7 @@ class PromptGenerateOptions(TypedDict, total=False): config: dict[str, Any] | ModelConfig | None messages: list[Message] | None docs: list[Document] | None - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None + tools: Sequence[str | Tool] | None resources: list[str] | None tool_choice: ToolChoice | None output: OutputOptions | None @@ -231,7 +232,7 @@ class PromptConfig(BaseModel): max_turns: int | None = None return_tool_requests: bool | None = None metadata: dict[str, Any] | None = None - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None + tools: Sequence[str | Tool] | None = None tool_choice: ToolChoice | None = None use: list[ModelMiddleware] | None = None docs: list[Document] | None = None @@ -263,7 +264,7 @@ def __init__( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, Any] | None = None, - tools: Sequence[str | Action[Any, Any, Any] | Callable[..., Awaitable[Any]]] | None = None, + tools: Sequence[str | 
Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -674,7 +675,7 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) resume_metadata=options.resume_metadata, ) - # Convert tool refs (name, Action, or @tool callable) to string names for GenerateActionOptions + # Convert tool refs (str name or Tool object) to string names for GenerateActionOptions tools_refs = tools_to_action_refs(options.tools) return GenerateActionOptions( diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index d490e0f3a6..16ca237230 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -19,18 +19,77 @@ import inspect from collections.abc import Callable from contextvars import ContextVar -from functools import wraps -from typing import Any, ParamSpec, TypeVar, cast +from typing import Any, cast -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import GenkitError from genkit._core._registry import Registry -from genkit._core._typing import Part, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart +from genkit._core._typing import Part, ToolDefinition, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart + + +class Tool(ToolDefinition): + """A registered tool: its :class:`ToolDefinition` wire format plus a callable and ``restart``.""" + + _action: Any = PrivateAttr() + + def __init__(self, *, action: Any, **data: Any) -> None: # noqa: ANN401 + super().__init__(**data) + self._action = action + + async def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + """Call the underlying action.""" + return (await self._action.run(*args, **kwargs)).response + + def restart( + self, + replace_input: Any | None = None, # noqa: ANN401 
+ *, + interrupt: ToolRequestPart, + resumed_metadata: dict[str, Any] | None = None, + ) -> ToolRequestPart: + """Create a restart request for an interrupted tool call. + + Args: + replace_input: Optional new ``tool_request.input`` for this run (previous input is + stored in ``metadata.replacedInput`` when this is set). + interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). + resumed_metadata: Passed to the tool as :attr:`ToolRunContext.resumed_metadata`. + + Returns: + A ``ToolRequestPart`` for ``resume_restart`` / message history. + + Example: + ``pay_invoice.restart({**trp.tool_request.input, "confirmed": True}, interrupt=trp,`` + ``resumed_metadata={"by": "bob"})`` + """ + tool_req = interrupt.tool_request + if tool_req.name != self.name: + raise ValueError(f"Interrupt is for tool '{tool_req.name}', not '{self.name}'") + + existing_meta = interrupt.metadata or {} + new_meta: dict[str, Any] = dict(existing_meta) if existing_meta else {} + + new_meta['resumed'] = resumed_metadata if resumed_metadata is not None else True + + new_input = tool_req.input + if replace_input is not None: + new_meta['replacedInput'] = tool_req.input + new_input = replace_input + + if 'interrupt' in new_meta: + del new_meta['interrupt'] + + return ToolRequestPart( + tool_request=ToolRequest( + name=tool_req.name, + ref=tool_req.ref, + input=new_input, + ), + metadata=new_meta, + ) -P = ParamSpec('P') -T = TypeVar('T') # Context variables for propagating resumed metadata to tools _tool_resumed_metadata: ContextVar[dict[str, Any] | None] = ContextVar('tool_resumed_metadata', default=None) @@ -106,8 +165,8 @@ class Interrupt(Exception): # noqa: N818 - public Genkit name; not renamed *Err with ``cause=Interrupt``; generation attaches interrupt metadata to the pending tool request. - To resume, use :func:`respond_to_interrupt` or :func:`restart_interrupted_tool` (or - ``tool.restart`` on the registered tool callable). 
+ To resume, use :func:`respond_to_interrupt` or ``tool.restart(...)`` on the + registered :class:`Tool`. """ def __init__(self, data: dict[str, Any] | None = None) -> None: @@ -157,43 +216,6 @@ def respond_to_interrupt( return _tool_response_part(interrupt, response, metadata) -def restart_interrupted_tool( - replace_input: Any | None = None, # noqa: ANN401 - new tool_request.input JSON or None to keep prior - *, - interrupt: ToolRequestPart, - tool: Callable[..., Any], - resumed_metadata: dict[str, Any] | None = None, -) -> ToolRequestPart: - """Build a restart ``ToolRequestPart`` for an interrupted tool call. - - Thin wrapper around ``tool.restart(...)`` on the registered tool callable (from - :func:`define_tool` or :meth:`Genkit.tool`). Use when you want a module-level helper - parallel to :func:`respond_to_interrupt`. - - Args: - replace_input: New ``tool_request.input`` for the redo, or ``None`` to keep the - interrupted request's input. - interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). - tool: The same tool function / wrapper that was registered (must expose ``.restart``). - resumed_metadata: Passed through to :attr:`ToolRunContext.resumed_metadata`. - - Returns: - A ``ToolRequestPart`` suitable for ``generate(..., resume_restart=[...])`` / message history. - """ - restart = getattr(tool, 'restart', None) - if restart is None: - raise TypeError(f'{tool!r} has no restart method; pass a tool from define_tool or Genkit.tool') - out = restart( - replace_input, - interrupt=interrupt, - resumed_metadata=resumed_metadata, - ) - if not isinstance(out, ToolRequestPart): - msg = f'Expected tool.restart() to return ToolRequestPart, got {type(out)!r}' - raise TypeError(msg) - return out - - async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolRequestPart) -> Part: """Run a tool for ``resume_restart``: applies ``resumed`` / ``replacedInput`` from metadata. 
@@ -253,12 +275,12 @@ def _get_func_description(func: Callable[..., Any], description: str | None = No def _define_tool( registry: Registry, - func: Callable[P, T], + func: Callable[..., Any], name: str | None = None, description: str | None = None, *, input_schema: type[BaseModel] | dict[str, object] | None = None, -) -> Callable[P, T]: +) -> Tool: """Register a function as a tool. Normally, the input_schema and output_schem are inferred from func. However, @@ -268,29 +290,26 @@ def _define_tool( In that case, the app developer can pass in an input_schema to override the inferred schema. This will ensure that the model requesting the tool will see the correct input shape. """ - # All Python functions have __name__, but ty is strict about Callable protocol if not inspect.iscoroutinefunction(func): - raise TypeError(f'Tool function must be async. Got sync function: {func.__name__}') # ty: ignore[unresolved-attribute] + raise TypeError(f'Tool function must be async. Got sync function: {getattr(func, "__name__", repr(func))}') tool_name = name if name is not None else getattr(func, '__name__', 'unnamed_tool') tool_description = _get_func_description(func, description) input_spec = inspect.getfullargspec(func) - func_any = cast(Callable[..., Any], func) - async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - arity dispatch; args/return follow registered tool # Dynamic dispatch by arity; payload types follow the registered tool (not expressible here). 
match len(input_spec.args): case 0: - return await func_any() + return await func() case 1: - return await func_any(args[0]) + return await func(args[0]) case 2: # Read from context variables for resumed metadata resumed_meta = _tool_resumed_metadata.get() original_input = _tool_original_input.get() - return await func_any( + return await func( args[0], ToolRunContext( cast(ActionRunContext, args[1]), @@ -311,74 +330,21 @@ async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - arity dispatch; if input_schema is not None: action._override_input_schema(input_schema) - @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: # noqa: ANN401 - # Return type is the tool's declared output; ParamSpec preserves call shape only. - action_any = cast(Any, action) - return (await action_any.run(*args, **kwargs)).response - - # Add restart method to the wrapper (use respond_to_interrupt for interrupt replies). - def restart( - replace_input: Any | None = None, # noqa: ANN401 - same as restart_interrupted_tool.replace_input - *, - interrupt: ToolRequestPart, - resumed_metadata: dict[str, Any] | None = None, - ) -> ToolRequestPart: - """Create a restart request for an interrupted tool call. - - Args: - replace_input: Optional new ``tool_request.input`` for this run (previous input is - stored in ``metadata.replacedInput`` when this is set). - interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). - resumed_metadata: Passed to the tool as :attr:`ToolRunContext.resumed_metadata`. - - Returns: - A ``ToolRequestPart`` for ``resume_restart`` / message history. 
- - Example: - ``pay_invoice.restart({**trp.tool_request.input, "confirmed": True}, interrupt=trp,`` - ``resumed_metadata={"by": "bob"})`` - """ - tool_req = interrupt.tool_request - if tool_req.name != tool_name: - raise ValueError(f"Interrupt is for tool '{tool_req.name}', not '{tool_name}'") - - existing_meta = interrupt.metadata or {} - new_meta: dict[str, Any] = dict(existing_meta) if existing_meta else {} - - # Mark as resumed - new_meta['resumed'] = resumed_metadata if resumed_metadata is not None else True - - # Store original input if replacing - new_input = tool_req.input - if replace_input is not None: - new_meta['replacedInput'] = tool_req.input - new_input = replace_input - - # Remove interrupt marker - if 'interrupt' in new_meta: - del new_meta['interrupt'] - - return ToolRequestPart( - tool_request=ToolRequest( - name=tool_req.name, - ref=tool_req.ref, - input=new_input, - ), - metadata=new_meta, - ) - - wrapper.restart = restart # type: ignore[attr-defined] - - return cast(Callable[P, T], wrapper) + return Tool( + name=tool_name, + description=tool_description or '', + input_schema=action.input_schema, + output_schema=action.output_schema, + action=action, + ) def define_tool( registry: Registry, - func: Callable[P, T], + func: Callable[..., Any], name: str | None = None, description: str | None = None, -) -> Callable[P, T]: +) -> Tool: """Register a function as a tool. Tool input/output JSON Schemas are inferred from ``func`` (first parameter and return type). @@ -402,7 +368,7 @@ def define_interrupt( description: str | None = None, request_metadata: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, # noqa: ANN401 input_schema: type[BaseModel] | dict[str, object] | None = None, -) -> Callable[..., Any]: +) -> Tool: """Register a tool that always interrupts execution. 
An interrupt tool is a special tool that always raises :class:`Interrupt` with diff --git a/py/packages/genkit/src/genkit/_core/_flow.py b/py/packages/genkit/src/genkit/_core/_flow.py index 2958832caf..e18371583f 100644 --- a/py/packages/genkit/src/genkit/_core/_flow.py +++ b/py/packages/genkit/src/genkit/_core/_flow.py @@ -67,9 +67,9 @@ def define_flow( """Register an async function as a flow action.""" # All Python functions have __name__, but ty is strict about Callable protocol if not inspect.iscoroutinefunction(func): - raise TypeError(f'Flow must be async: {func.__name__}') # ty: ignore[unresolved-attribute] + raise TypeError(f'Flow must be async: {getattr(func, "__name__", repr(func))}') - flow_name = name or func.__name__ # ty: ignore[unresolved-attribute] + flow_name = name or getattr(func, '__name__', None) or 'unnamed_flow' return registry.register_action( name=flow_name, kind=ActionKind.FLOW, diff --git a/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py index 54f9cf4203..dca72340b2 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py @@ -3,8 +3,6 @@ """Unit tests for private helpers in genkit._ai._generate (interrupt / resume).""" -import pytest - from genkit._ai._generate import ( _find_corresponding_restart, _find_corresponding_tool_response, @@ -13,7 +11,7 @@ ) from genkit._ai._tools import Interrupt from genkit._core._error import GenkitError -from genkit._core._typing import Part, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart +from genkit._core._typing import ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart def test_find_corresponding_restart_matches_name_and_ref() -> None: diff --git a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py index 5fc3f171ab..e05b273328 100644 --- 
a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py @@ -11,8 +11,8 @@ from genkit import Genkit, Message, ModelResponse from genkit._ai._generate import _resolve_resumed_tool_request, generate_action -from genkit._ai._tools import Interrupt, ToolRunContext, respond_to_interrupt from genkit._ai._testing import define_programmable_model +from genkit._ai._tools import Interrupt, ToolRunContext, respond_to_interrupt from genkit._core._error import GenkitError from genkit._core._model import GenerateActionOptions from genkit._core._typing import FinishReason, Part, Resume @@ -23,7 +23,9 @@ def _wire(messages: list[Message]) -> list[dict[str, Any]]: return [m.model_dump(mode='json', exclude_none=True, by_alias=True) for m in messages] -def _gen_opts(ai: Genkit, *, tools: list[str], messages: list[Message], resume: Resume | None = None) -> GenerateActionOptions: +def _gen_opts( + ai: Genkit, *, tools: list[str], messages: list[Message], resume: Resume | None = None +) -> GenerateActionOptions: return GenerateActionOptions( model='programmableModel', messages=messages, @@ -56,16 +58,14 @@ async def u2(_: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 pm.responses.append( ModelResponse( finish_reason=FinishReason.STOP, - message=Message.model_validate( - { - 'role': 'model', - 'content': [ - {'text': 'go'}, - {'toolRequest': {'ref': '1', 'name': 'u1', 'input': {}}}, - {'toolRequest': {'ref': '2', 'name': 'u2', 'input': {}}}, - ], - } - ), + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'go'}, + {'toolRequest': {'ref': '1', 'name': 'u1', 'input': {}}}, + {'toolRequest': {'ref': '2', 'name': 'u2', 'input': {}}}, + ], + }), ) ) pm.responses.append( @@ -77,7 +77,9 @@ async def u2(_: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 r = await generate_action( ai.registry, - _gen_opts(ai, tools=['u1', 'u2'], messages=[Message.model_validate({'role': 
'user', 'content': [{'text': 'hi'}]})]), + _gen_opts( + ai, tools=['u1', 'u2'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})] + ), ) assert seen == [(False, None, None), (False, None, None)] assert _wire(r.messages) == [ @@ -122,15 +124,13 @@ async def intr(_: dict) -> str: # noqa: ARG001 pm.responses.append( ModelResponse( finish_reason=FinishReason.STOP, - message=Message.model_validate( - { - 'role': 'model', - 'content': [ - {'text': 'call'}, - {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, - ], - } - ), + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'call'}, + {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, + ], + }), ) ) @@ -173,15 +173,13 @@ async def intr(_: dict) -> str: # noqa: ARG001 pm.responses.append( ModelResponse( finish_reason=FinishReason.STOP, - message=Message.model_validate( - { - 'role': 'model', - 'content': [ - {'text': 'call'}, - {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, - ], - } - ), + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'call'}, + {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, + ], + }), ) ) pm.responses.append( @@ -271,21 +269,19 @@ async def bank_transfer(inp: dict) -> int: pm.responses.append( ModelResponse( finish_reason=FinishReason.STOP, - message=Message.model_validate( - { - 'role': 'model', - 'content': [ - {'text': 't'}, - { - 'toolRequest': { - 'ref': 'g', - 'name': 'bank_transfer', - 'input': {'preapproved': False}, - }, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 't'}, + { + 'toolRequest': { + 'ref': 'g', + 'name': 'bank_transfer', + 'input': {'preapproved': False}, }, - ], - } - ), + }, + ], + }), ) ) r_fail = await generate_action( @@ -321,21 +317,19 @@ async def bank_transfer(inp: dict) -> int: pm.responses.append( ModelResponse( finish_reason=FinishReason.STOP, - message=Message.model_validate( - { - 'role': 'model', - 
'content': [ - {'text': 't2'}, - { - 'toolRequest': { - 'ref': 'g2', - 'name': 'bank_transfer', - 'input': {'preapproved': True}, - }, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 't2'}, + { + 'toolRequest': { + 'ref': 'g2', + 'name': 'bank_transfer', + 'input': {'preapproved': True}, }, - ], - } - ), + }, + ], + }), ) ) pm.responses.append( @@ -404,15 +398,13 @@ async def pay(inp: dict) -> str: pm.responses.append( ModelResponse( finish_reason=FinishReason.STOP, - message=Message.model_validate( - { - 'role': 'model', - 'content': [ - {'text': 'x'}, - {'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}}, - ], - } - ), + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'x'}, + {'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}}, + ], + }), ) ) pm.responses.append( @@ -504,16 +496,14 @@ async def b_tool(inp: dict) -> str: pm.responses.append( ModelResponse( finish_reason=FinishReason.STOP, - message=Message.model_validate( - { - 'role': 'model', - 'content': [ - {'text': 'both'}, - {'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}}, - {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, - ], - } - ), + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'both'}, + {'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}}, + {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, + ], + }), ) ) pm.responses.append( @@ -525,7 +515,9 @@ async def b_tool(inp: dict) -> str: first = await generate_action( ai.registry, - _gen_opts(ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + _gen_opts( + ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})] + ), ) assert first.finish_reason == FinishReason.INTERRUPTED assert _wire(first.messages) == [ @@ -610,12 +602,10 @@ async def test_pending_output_trp_yields_tool_response_with_source_pending() -> directly, not a 
full message list. """ ai = Genkit() - part = Part.model_validate( - { - 'toolRequest': {'name': 't', 'ref': 'r', 'input': {}}, - 'metadata': {'pendingOutput': 123}, - } - ) + part = Part.model_validate({ + 'toolRequest': {'name': 't', 'ref': 'r', 'input': {}}, + 'metadata': {'pendingOutput': 123}, + }) _, resp = await _resolve_resumed_tool_request( ai.registry, GenerateActionOptions( @@ -645,17 +635,15 @@ async def test_resume_without_matching_replies_raises() -> None: model='programmableModel', messages=[ Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]}), - Message.model_validate( - { - 'role': 'model', - 'content': [ - { - 'toolRequest': {'ref': 'z', 'name': 'missing', 'input': {}}, - 'metadata': {'interrupt': True}, - }, - ], - } - ), + Message.model_validate({ + 'role': 'model', + 'content': [ + { + 'toolRequest': {'ref': 'z', 'name': 'missing', 'input': {}}, + 'metadata': {'interrupt': True}, + }, + ], + }), ], resume=Resume(), ), @@ -681,4 +669,3 @@ async def test_resume_requires_last_message_model_with_tool_requests() -> None: ), ) assert 'model' in ei.value.original_message.lower() - diff --git a/py/packages/genkit/tests/genkit/ai/generate_test.py b/py/packages/genkit/tests/genkit/ai/generate_test.py index f24da35c0e..ecb05be823 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_test.py @@ -16,13 +16,13 @@ from genkit import ActionKind, Document, Genkit, Message, ModelResponse, ModelResponseChunk from genkit._ai._generate import generate_action -from genkit._ai._tools import Interrupt from genkit._ai._model import text_from_content, text_from_message from genkit._ai._testing import ( ProgrammableModel, define_echo_model, define_programmable_model, ) +from genkit._ai._tools import Interrupt from genkit._core._action import ActionRunContext from genkit._core._model import GenerateActionOptions, ModelRequest from genkit._core._typing import ( diff --git 
a/py/samples/tool-interrupts/src/approval_example.py b/py/samples/tool-interrupts/src/approval_example.py index 7ed6b684d5..4d10022c07 100644 --- a/py/samples/tool-interrupts/src/approval_example.py +++ b/py/samples/tool-interrupts/src/approval_example.py @@ -17,7 +17,7 @@ """**Bank transfer approval** — human-in-the-loop before a transfer tool finishes. The model calls ``request_transfer``; the CLI asks **approve (y)** or **decline (n)**. -**Approve** → ``restart_interrupted_tool`` / ``resume_restart`` so the tool **runs again** +**Approve** → ``tool.restart(...)`` / ``resume_restart`` so the tool **runs again** with :class:`~genkit.ToolRunContext.is_resumed`. **Decline** → ``respond_to_interrupt`` / ``resume_respond`` (no second tool run). @@ -38,7 +38,6 @@ Interrupt, ToolRunContext, respond_to_interrupt, - restart_interrupted_tool, ) from genkit.model import ModelResponse from genkit.plugins.google_genai import GoogleAI # pyright: ignore[reportMissingImports] @@ -147,7 +146,7 @@ async def interactive_restart_cli() -> None: while response.interrupts: interrupt = response.interrupts[0] name = interrupt.tool_request.name - if name != request_transfer.__name__: + if name != request_transfer.name: _print_unexpected_tool(name) return @@ -157,9 +156,8 @@ async def interactive_restart_cli() -> None: ans = input('Approve transfer? 
[y/N]: ').strip().lower() if ans in ('y', 'yes'): - restart = restart_interrupted_tool( + restart = request_transfer.restart( interrupt=interrupt, - tool=request_transfer, resumed_metadata={'via': 'cli', 'path': 'restart'}, ) response = await ai.generate( diff --git a/py/samples/tool-interrupts/src/respond_example.py b/py/samples/tool-interrupts/src/respond_example.py index 83cecda494..5a3a9a979d 100755 --- a/py/samples/tool-interrupts/src/respond_example.py +++ b/py/samples/tool-interrupts/src/respond_example.py @@ -116,7 +116,7 @@ def show(label: str, r: ModelResponse) -> None: interrupt = response.interrupts[0] name = interrupt.tool_request.name - if name != present_questions.__name__: + if name != present_questions.name: print(f'Unexpected tool: {name!r}') return From ed69312cff501e5c4840296600cd66449a624e86 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Sun, 29 Mar 2026 18:21:37 -0700 Subject: [PATCH 09/15] clean up tool design --- py/packages/genkit/src/genkit/_ai/_aio.py | 20 ----- .../genkit/src/genkit/_ai/_generate.py | 10 +-- py/packages/genkit/src/genkit/_ai/_prompt.py | 6 +- py/packages/genkit/src/genkit/_ai/_tools.py | 64 +++++++++++---- .../genkit/src/genkit/_core/_registry.py | 6 -- .../genkit/tests/genkit/ai/_tools_test.py | 78 ++++++++++++++++++- .../genkit/tests/genkit/ai/genkit_api_test.py | 24 +----- .../genkit/tests/genkit/ai/prompt_test.py | 21 +---- .../plugins/google_genai/models/gemini.py | 3 +- py/samples/dynamic-tools/README.md | 18 ----- py/samples/dynamic-tools/pyproject.toml | 18 ----- py/samples/dynamic-tools/src/main.py | 74 ------------------ py/uv.lock | 22 ------ 13 files changed, 138 insertions(+), 226 deletions(-) delete mode 100644 py/samples/dynamic-tools/README.md delete mode 100644 py/samples/dynamic-tools/pyproject.toml delete mode 100644 py/samples/dynamic-tools/src/main.py diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index 1fc8844306..f92e070e14 100644 --- 
a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -1110,26 +1110,6 @@ def current_context() -> dict[str, Any] | None: """Get the current execution context, or None if not in an action.""" return ActionRunContext._current_context() # pyright: ignore[reportPrivateUsage] - def dynamic_tool( - self, - *, - name: str, - fn: Callable[..., object], - description: str | None = None, - metadata: dict[str, object] | None = None, - ) -> Action: - """Create an unregistered tool action for passing directly to generate().""" - tool_meta: dict[str, object] = metadata.copy() if metadata else {} - tool_meta['type'] = 'tool' - tool_meta['dynamic'] = True - return Action( - kind=ActionKind.TOOL, - name=name, - fn=fn, # type: ignore[arg-type] # dynamic tools may be sync - description=description, - metadata=tool_meta, - ) - async def flush_tracing(self) -> None: """Flush all pending trace spans to exporters.""" provider = trace_api.get_tracer_provider() diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 4e10161f49..fccec76d64 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -58,7 +58,7 @@ logger = get_logger(__name__) -def tools_to_action_refs( +def tools_to_action_names( tools: Sequence[str | Tool] | None, ) -> list[str] | None: """Normalize tool arguments to registry names for :class:`GenerateActionOptions`. 
@@ -68,13 +68,13 @@ def tools_to_action_refs( """ if tools is None: return None - refs: list[str] = [] + names: list[str] = [] for t in tools: if isinstance(t, str): - refs.append(t) + names.append(t) else: - refs.append(t.name) - return refs + names.append(t.name) + return names # Matches data URIs: everything up to the first comma is the media-type + diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 4d453831b0..7bcfef00c3 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -38,7 +38,7 @@ generate_action, resolve_tool, to_tool_definition, - tools_to_action_refs, + tools_to_action_names, ) from genkit._ai._model import ( Message, @@ -506,7 +506,7 @@ def _or(opt_val: Any, default: Any) -> Any: # noqa: ANN401 model=model_name, messages=resolved_msgs, # type: ignore[arg-type] config=prompt_options.config, - tools=tools_to_action_refs(prompt_options.tools), + tools=tools_to_action_names(prompt_options.tools), return_tool_requests=prompt_options.return_tool_requests, tool_choice=prompt_options.tool_choice, output=output_config, @@ -676,7 +676,7 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) ) # Convert tool refs (str name or Tool object) to string names for GenerateActionOptions - tools_refs = tools_to_action_refs(options.tools) + tools_refs = tools_to_action_names(options.tools) return GenerateActionOptions( model=model, diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index 16ca237230..a42b33e779 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -21,7 +21,7 @@ from contextvars import ContextVar from typing import Any, cast -from pydantic import BaseModel, PrivateAttr +from pydantic import BaseModel from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import 
GenkitError @@ -29,19 +29,61 @@ from genkit._core._typing import Part, ToolDefinition, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart -class Tool(ToolDefinition): - """A registered tool: its :class:`ToolDefinition` wire format plus a callable and ``restart``.""" +class Tool: + """A registered tool: a callable handle backed by an :class:`~genkit._core._action.Action`. - _action: Any = PrivateAttr() + Obtain instances via :func:`define_tool`, :func:`define_interrupt`, or the + ``@ai.tool`` decorator rather than constructing directly. + """ - def __init__(self, *, action: Any, **data: Any) -> None: # noqa: ANN401 - super().__init__(**data) + def __init__(self, action: Action) -> None: self._action = action + # ------------------------------------------------------------------ + # Properties that delegate to the underlying Action + # ------------------------------------------------------------------ + + @property + def name(self) -> str: + """Tool name (registry key).""" + return self._action.name + + @property + def description(self) -> str: + """Human-readable description sent to the model.""" + return self._action.description or '' + + @property + def input_schema(self) -> dict[str, object] | None: + """JSON Schema for the tool's input, as sent on the wire.""" + return self._action.input_schema + + @property + def output_schema(self) -> dict[str, object] | None: + """JSON Schema for the tool's output.""" + return self._action.output_schema + + def definition(self) -> ToolDefinition: + """Return the wire-format :class:`ToolDefinition` for this tool.""" + return ToolDefinition( + name=self.name, + description=self.description, + input_schema=self.input_schema, + output_schema=self.output_schema, + ) + + # ------------------------------------------------------------------ + # Execution + # ------------------------------------------------------------------ + async def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 - """Call the underlying 
action.""" + """Run the tool and return the unwrapped response value.""" return (await self._action.run(*args, **kwargs)).response + # ------------------------------------------------------------------ + # Interrupt / restart helpers + # ------------------------------------------------------------------ + def restart( self, replace_input: Any | None = None, # noqa: ANN401 @@ -330,13 +372,7 @@ async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - arity dispatch; if input_schema is not None: action._override_input_schema(input_schema) - return Tool( - name=tool_name, - description=tool_description or '', - input_schema=action.input_schema, - output_schema=action.output_schema, - action=action, - ) + return Tool(action) def define_tool( diff --git a/py/packages/genkit/src/genkit/_core/_registry.py b/py/packages/genkit/src/genkit/_core/_registry.py index 9df7c99778..b4731900be 100644 --- a/py/packages/genkit/src/genkit/_core/_registry.py +++ b/py/packages/genkit/src/genkit/_core/_registry.py @@ -41,7 +41,6 @@ ModelResponseChunk, ) from genkit._core._plugin import Plugin -from genkit._core._schema import to_json_schema from genkit._core._typing import ( EmbedRequest, EmbedResponse, @@ -138,8 +137,6 @@ def register_action( description: str | None = None, metadata: dict[str, object] | None = None, span_metadata: dict[str, SpanAttributeValue] | None = None, - *, - output_schema: type[BaseModel] | dict[str, object] | None = None, ) -> Action[InputT, OutputT, ChunkT]: """Register a new action with the registry. @@ -155,7 +152,6 @@ def register_action( description: Optional human-readable description of the action. metadata: Optional dictionary of metadata about the action. span_metadata: Optional dictionary of tracing span metadata. - output_schema: Optional explicit output JSON schema (Pydantic model or dict). Returns: The newly created and registered Action instance. 
@@ -169,8 +165,6 @@ def register_action( metadata=metadata, span_metadata=span_metadata, ) - if output_schema is not None: - action.output_schema = to_json_schema(output_schema) action_typed = cast(Action[InputT, OutputT, ChunkT], action) with self._lock: if kind not in self._entries: diff --git a/py/packages/genkit/tests/genkit/ai/_tools_test.py b/py/packages/genkit/tests/genkit/ai/_tools_test.py index 818874e9df..4d02e742dc 100644 --- a/py/packages/genkit/tests/genkit/ai/_tools_test.py +++ b/py/packages/genkit/tests/genkit/ai/_tools_test.py @@ -12,9 +12,10 @@ ToolRunContext, _tool_original_input, _tool_resumed_metadata, + respond_to_interrupt, ) from genkit._core._error import GenkitError -from genkit._core._typing import ToolRequest, ToolRequestPart +from genkit._core._typing import ToolRequest, ToolRequestPart, ToolResponsePart @pytest.mark.asyncio @@ -189,3 +190,78 @@ async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 await run_tool_after_restart(action, restart_trp) msg = ei.value.original_message.lower() assert 'restart' in msg or 'nested' in msg or 'not supported' in msg + + +# --------------------------------------------------------------------------- +# Wire-format tests: respond_to_interrupt +# --------------------------------------------------------------------------- + + +def test_respond_to_interrupt_wire_format_basic() -> None: + """respond_to_interrupt produces a ToolResponsePart with matching ref/name and interruptResponse metadata.""" + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='ask_user', ref='ref-abc', input={'question': 'ok?'}), + metadata={'interrupt': {'reason': 'needs_approval'}}, + ) + + result = respond_to_interrupt('yes', interrupt=interrupt_trp) + + assert isinstance(result, ToolResponsePart) + assert result.tool_response.name == 'ask_user' + assert result.tool_response.ref == 'ref-abc' + assert result.tool_response.output == 'yes' + assert result.metadata is not None + assert 
result.metadata.get('interruptResponse') is True + + +def test_respond_to_interrupt_wire_format_with_metadata() -> None: + """respond_to_interrupt attaches custom metadata under interruptResponse key.""" + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='confirm', ref='ref-xyz', input={}), + metadata={'interrupt': True}, + ) + + result = respond_to_interrupt({'approved': True}, interrupt=interrupt_trp, metadata={'by': 'admin'}) + + assert result.tool_response.ref == 'ref-xyz' + assert result.tool_response.output == {'approved': True} + assert result.metadata is not None + assert result.metadata.get('interruptResponse') == {'by': 'admin'} + + +def test_restart_preserves_ref_on_wire() -> None: + """restart() preserves the original tool_request.ref so the resumed TRP can be correlated.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='corr-id-1', input={'amount': 50}), + metadata={'interrupt': True}, + ) + out = pay.restart(None, interrupt=interrupt_trp) + + assert out.tool_request.ref == 'corr-id-1' + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_response_preserves_ref() -> None: + """run_tool_after_restart produces a ToolResponsePart whose ref matches the restart TRP's ref.""" + ai = Genkit() + + @ai.tool(name='t_ref') + async def t_ref(inp: dict) -> str: # noqa: ARG001 + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t_ref') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t_ref', ref='wire-ref-99', input={}), + metadata={'resumed': True}, + ) + part = await run_tool_after_restart(action, restart_trp) + assert part.root.tool_response.ref == 'wire-ref-99' + diff --git a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py index c47a4e6cd0..6b6cbe0474 
100644 --- a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py +++ b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py @@ -13,7 +13,7 @@ from opentelemetry.sdk.trace import TracerProvider from genkit import Genkit -from genkit._core._action import Action, ActionKind, _action_context +from genkit._core._action import _action_context from genkit._core._typing import Operation @@ -40,28 +40,6 @@ def sync_fn() -> str: await ai.run(name='test3', fn=sync_fn) # type: ignore[arg-type] -@pytest.mark.asyncio -async def test_genkit_dynamic_tool() -> None: - """Test Genkit.dynamic_tool method.""" - ai = Genkit() - - async def my_tool(x: int) -> int: - return x + 1 - - tool = ai.dynamic_tool(name='my_tool', fn=my_tool, description='increment x') - - assert isinstance(tool, Action) - assert tool.kind == ActionKind.TOOL - assert tool.name == 'my_tool' - assert tool.description == 'increment x' - assert tool.metadata.get('type') == 'tool' - assert tool.metadata.get('dynamic') is True - - # Execution - resp = await tool.run(5) - assert resp.response == 6 - - @pytest.mark.asyncio async def test_genkit_check_operation() -> None: """Test Genkit.check_operation method.""" diff --git a/py/packages/genkit/tests/genkit/ai/prompt_test.py b/py/packages/genkit/tests/genkit/ai/prompt_test.py index 24c31a3dc7..8a62fd47da 100644 --- a/py/packages/genkit/tests/genkit/ai/prompt_test.py +++ b/py/packages/genkit/tests/genkit/ai/prompt_test.py @@ -241,26 +241,7 @@ async def test_prompt_rendering_dotprompt( """Test prompt rendering.""" ai, *_ = setup_test() - my_prompt = ai.define_prompt( - model=prompt.get('model'), - config=prompt.get('config'), - description=prompt.get('description'), - input_schema=prompt.get('input_schema'), - system=prompt.get('system'), - prompt=prompt.get('prompt'), - messages=prompt.get('messages'), - output_format=prompt.get('output_format'), - output_content_type=prompt.get('output_content_type'), - output_instructions=prompt.get('output_instructions'), - 
output_constrained=prompt.get('output_constrained'), - max_turns=prompt.get('max_turns'), - return_tool_requests=prompt.get('return_tool_requests'), - metadata=prompt.get('metadata'), - tools=prompt.get('tools'), - tool_choice=prompt.get('tool_choice'), - use=prompt.get('use'), - docs=prompt.get('docs'), - ) + my_prompt = ai.define_prompt(**prompt) # New API: use opts parameter to pass config and context response = await my_prompt(input, config=input_option, context=context) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index a7db7f62ea..5c2563a594 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -1078,8 +1078,7 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: Genai tool compatible with Gemini API. """ params = self._convert_schema_property(tool.input_schema) - # Fix for no-arg tools: parameters cannot be None if we want the tool to be callable? - # Actually Google GenAI expects type=OBJECT for params usually. + # Empty params: Gemini requires type=OBJECT even for no-arg tools. if not params: params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={}) # Gemini rejects scalar/array root schemas. Tool params must be OBJECT with diff --git a/py/samples/dynamic-tools/README.md b/py/samples/dynamic-tools/README.md deleted file mode 100644 index 5377eac513..0000000000 --- a/py/samples/dynamic-tools/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# Dynamic Tools - -Learn two related ideas: - -- `ai.dynamic_tool()` creates a tool at runtime. -- `ai.run()` traces a plain async function as a named step. 
- -```bash -export GEMINI_API_KEY=your-api-key -uv sync -uv run src/main.py -``` - -To inspect the same flows in Dev UI: - -```bash -genkit start -- uv run src/main.py -``` diff --git a/py/samples/dynamic-tools/pyproject.toml b/py/samples/dynamic-tools/pyproject.toml deleted file mode 100644 index 9660490674..0000000000 --- a/py/samples/dynamic-tools/pyproject.toml +++ /dev/null @@ -1,18 +0,0 @@ -[project] -name = "dynamic-tools" -version = "0.2.0" -requires-python = ">=3.10" -dependencies = [ - "genkit", - "genkit-plugin-google-genai", - "pydantic>=2.0.0", - "structlog>=24.0.0", - "uvloop>=0.21.0", -] - -[build-system] -build-backend = "hatchling.build" -requires = ["hatchling"] - -[tool.hatch.build.targets.wheel] -packages = ["src"] diff --git a/py/samples/dynamic-tools/src/main.py b/py/samples/dynamic-tools/src/main.py deleted file mode 100644 index bb200eb1aa..0000000000 --- a/py/samples/dynamic-tools/src/main.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# SPDX-License-Identifier: Apache-2.0 - -"""Dynamic tools - create tools at runtime and trace plain functions.""" - -from pydantic import BaseModel, Field - -from genkit import Genkit -from genkit.plugins.google_genai import GoogleAI - -ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-2.5-flash') - - -class DynamicToolInput(BaseModel): - """Input for runtime tool creation.""" - - value: int = Field(default=5, description='Value to square with a runtime tool') - - -class RunStepInput(BaseModel): - """Input for traced step demo.""" - - text: str = Field(default='hello dynamic tools', description='Text to process with traced steps') - - -@ai.flow() -async def dynamic_tool_demo(input: DynamicToolInput) -> str: - """Create a tool at runtime and call it immediately.""" - - async def square(value: int) -> int: - return value * value - - tool = ai.dynamic_tool(name='square', description='Square a number', fn=square) - result = await tool.run(input.value) - return f'{input.value} squared is {result.response}' - - -@ai.flow() -async def run_step_demo(input: RunStepInput) -> dict[str, int | str]: - """Wrap plain async functions in traceable `ai.run()` steps.""" - - async def normalize() -> str: - return input.text.strip().lower() - - async def count_words() -> int: - return len(normalized.split()) - - normalized = await ai.run(name='normalize_text', fn=normalize) - word_count = await ai.run(name='count_words', fn=count_words) - return {'normalized': normalized, 'word_count': word_count} - - -async def main() -> None: - """Run both dynamic tool demos once.""" - - print(await dynamic_tool_demo(DynamicToolInput())) # noqa: T201 - print(await run_step_demo(RunStepInput())) # noqa: T201 - - -if __name__ == '__main__': - ai.run_main(main()) diff --git a/py/uv.lock b/py/uv.lock index adcfe31bcb..ae1837a0ad 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -12,7 +12,6 @@ resolution-markers = [ [manifest] members = [ "context", - "dynamic-tools", "evaluators", "fastapi-bugbot", 
"flask-hello", @@ -1223,27 +1222,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/09/d09dfaa2110884284be6006b7586ea519f7391de58ed5428f2bf457bcd03/dotpromptz_handlebars-0.1.8-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f23498821610d443a67c860922aba00d20bdd80b8421bfef0ceff07b713f8198", size = 666257, upload-time = "2026-01-30T06:44:46.929Z" }, ] -[[package]] -name = "dynamic-tools" -version = "0.2.0" -source = { editable = "samples/dynamic-tools" } -dependencies = [ - { name = "genkit" }, - { name = "genkit-plugin-google-genai" }, - { name = "pydantic" }, - { name = "structlog" }, - { name = "uvloop" }, -] - -[package.metadata] -requires-dist = [ - { name = "genkit", editable = "packages/genkit" }, - { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" }, - { name = "pydantic", specifier = ">=2.0.0" }, - { name = "structlog", specifier = ">=24.0.0" }, - { name = "uvloop", specifier = ">=0.21.0" }, -] - [[package]] name = "email-validator" version = "2.3.0" From 8398540d6d7dce621c6f45fcd5faf8a5d120fbf4 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 30 Mar 2026 15:54:55 -0700 Subject: [PATCH 10/15] fix(py): Fix pyproject.toml on fastapi plugin --- py/plugins/fastapi/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/plugins/fastapi/pyproject.toml b/py/plugins/fastapi/pyproject.toml index d55139a55f..9438a86c41 100644 --- a/py/plugins/fastapi/pyproject.toml +++ b/py/plugins/fastapi/pyproject.toml @@ -66,4 +66,4 @@ build-backend = "hatchling.build" requires = ["hatchling"] [tool.hatch.build.targets.wheel] -packages = ["src/genkit", "src/genkit/plugins"] +only-include = ["src/genkit/plugins/fastapi"] From 1374dae9168b872e20ecafd8e243513a8f20bc77 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 30 Mar 2026 23:35:25 -0700 Subject: [PATCH 11/15] address comments --- py/packages/genkit/src/genkit/_ai/_aio.py | 15 +- .../genkit/src/genkit/_ai/_generate.py | 79 +++++----- 
py/packages/genkit/src/genkit/_ai/_prompt.py | 4 +- py/packages/genkit/src/genkit/_ai/_tools.py | 36 +++-- .../genkit/tests/genkit/ai/_tools_test.py | 38 ++++- .../ai/generate_interrupt_resume_test.py | 135 +++++++++++++++--- .../genkit/tests/genkit/ai/prompt_test.py | 4 +- .../tool-interrupts/src/approval_example.py | 2 +- .../tool-interrupts/src/respond_example.py | 3 +- 9 files changed, 212 insertions(+), 104 deletions(-) diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index f92e070e14..61b4b93cc3 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -841,7 +841,12 @@ async def generate( use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[Any]: - """Generate text or structured data using a language model.""" + """Generate text or structured data using a language model. + + ``tools`` is typed as ``Sequence`` rather than ``list`` because ``Sequence`` + is covariant: ``list[Tool]`` or ``list[str]`` are both assignable to + ``Sequence[str | Tool]``, but not to ``list[str | Tool]``. + """ return await generate_action( self.registry, await to_generate_action_options( @@ -952,13 +957,7 @@ def generate_stream( docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[Any]: - """Stream generated text, returning a ModelStreamResponse with .stream and .response. - - Middleware (``use=``) uses the same signatures as :meth:`generate`: - - Simple: ``(req, ctx, next)`` — 3 params - - Streaming-aware: ``(req, ctx, on_chunk, next)`` — 4 params - The framework auto-detects by parameter count. Both work with generate_stream. 
- """ + """Stream generated text, returning a ModelStreamResponse with .stream and .response.""" channel: Channel[ModelResponseChunk, ModelResponse[Any]] = Channel(timeout=timeout) async def _run_generate() -> ModelResponse[Any]: diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index fccec76d64..d00f8c30ac 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -61,10 +61,10 @@ def tools_to_action_names( tools: Sequence[str | Tool] | None, ) -> list[str] | None: - """Normalize tool arguments to registry names for :class:`GenerateActionOptions`. + """Normalize tool arguments to registry names for GenerateActionOptions. - Each item may be a tool name (``str``) or a :class:`~genkit.Tool` returned by - :meth:`~genkit.Genkit.tool`. + Each item may be a tool name (``str``) or a Tool returned by + Genkit.tool(). """ if tools is None: return None @@ -675,7 +675,7 @@ def _to_pending_response(request: ToolRequestPart, response: ToolResponsePart) - def _interrupt_from_tool_exc(exc: BaseException) -> Interrupt | None: - """If ``exc`` is (or wraps) :class:`~genkit._ai._tools.Interrupt`, return that interrupt.""" + """If ``exc`` is (or wraps) an Interrupt exception, return that interrupt.""" if isinstance(exc, Interrupt): return exc if isinstance(exc, GenkitError) and exc.cause is not None and isinstance(exc.cause, Interrupt): @@ -724,12 +724,11 @@ async def _resolve_tool_request( async def resolve_tool(registry: Registry, tool_name: str) -> Action: - """Resolve a :class:`~genkit.ActionKind.TOOL` action by name from the registry.""" + """Resolve a tool action by name from the registry.""" tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name) - if tool is not None: - return tool - msg = f'Unable to resolve tool {tool_name}' - raise ValueError(msg) + if tool is None: + raise ValueError(f'Unable to resolve tool {tool_name}') + return tool 
async def _resolve_resume_options( @@ -761,8 +760,8 @@ async def _resolve_resume_options( continue resumed_request, resumed_response = await _resolve_resumed_tool_request(_registry, raw_request, part) - tool_responses.append(resumed_response) - updated_content[i] = resumed_request + tool_responses.append(Part(root=resumed_response)) + updated_content[i] = Part(root=resumed_request) i += 1 last_message.content = updated_content @@ -787,7 +786,7 @@ async def _resolve_resume_options( async def _resolve_resumed_tool_request( registry: Registry, raw_request: GenerateActionOptions, tool_request_part: Part -) -> tuple[Part, Part]: +) -> tuple[ToolRequestPart, ToolResponsePart]: """Resolve a single tool request from pending output, resume.respond, or resume.restart.""" # Type narrowing: ensure we're working with a ToolRequestPart if not isinstance(tool_request_part.root, ToolRequestPart): @@ -804,17 +803,14 @@ async def _resolve_resumed_tool_request( del metadata['pendingOutput'] metadata['source'] = 'pending' return ( - tool_request_part, - # Part is a RootModel, so we pass content via 'root' parameter - Part( - root=ToolResponsePart( - tool_response=ToolResponse( - name=tool_req_root.tool_request.name, - ref=tool_req_root.tool_request.ref, - output=pending_output.model_dump() if isinstance(pending_output, BaseModel) else pending_output, - ), - metadata=metadata, - ) + tool_req_root, + ToolResponsePart( + tool_response=ToolResponse( + name=tool_req_root.tool_request.name, + ref=tool_req_root.tool_request.ref, + output=pending_output.model_dump() if isinstance(pending_output, BaseModel) else pending_output, + ), + metadata=metadata, ), ) @@ -830,16 +826,13 @@ async def _resolve_resumed_tool_request( if interrupt: del metadata['interrupt'] return ( - # Part is a RootModel, so we pass content via 'root' parameter - Part( - root=ToolRequestPart( - tool_request=ToolRequest( - name=tool_req_root.tool_request.name, - ref=tool_req_root.tool_request.ref, - 
input=tool_req_root.tool_request.input, - ), - metadata={**metadata, 'resolvedInterrupt': interrupt}, - ) + ToolRequestPart( + tool_request=ToolRequest( + name=tool_req_root.tool_request.name, + ref=tool_req_root.tool_request.ref, + input=tool_req_root.tool_request.input, + ), + metadata={**metadata, 'resolvedInterrupt': interrupt}, ), provided_response, ) @@ -856,15 +849,13 @@ async def _resolve_resumed_tool_request( if interrupt: del metadata['interrupt'] return ( - Part( - root=ToolRequestPart( - tool_request=ToolRequest( - name=tool_req_root.tool_request.name, - ref=tool_req_root.tool_request.ref, - input=tool_req_root.tool_request.input, - ), - metadata={**metadata, 'resolvedInterrupt': interrupt}, - ) + ToolRequestPart( + tool_request=ToolRequest( + name=tool_req_root.tool_request.name, + ref=tool_req_root.tool_request.ref, + input=tool_req_root.tool_request.input, + ), + metadata={**metadata, 'resolvedInterrupt': interrupt}, ), executed, ) @@ -890,11 +881,11 @@ def _find_corresponding_restart( return None -def _find_corresponding_tool_response(responses: list[ToolResponsePart], request: ToolRequestPart) -> Part | None: +def _find_corresponding_tool_response(responses: list[ToolResponsePart], request: ToolRequestPart) -> ToolResponsePart | None: """Find a response matching the request by name and ref.""" for p in responses: if p.tool_response.name == request.tool_request.name and p.tool_response.ref == request.tool_request.ref: - return Part(root=p) + return p return None diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 7bcfef00c3..a22437959a 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -109,7 +109,7 @@ def resume_options_to_resume( resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, resume_metadata: dict[str, Any] | None = None, ) -> Resume | None: - """Build wire :class:`Resume` from flat keyword options 
(``generate`` / prompts).""" + """Build wire Resume from flat keyword options (``generate`` / prompts).""" respond = _normalize_resume_respond_parts(resume_respond) restart = _normalize_resume_restart_parts(resume_restart) if respond is None and restart is None and resume_metadata is None: @@ -542,7 +542,7 @@ async def render( ) -> GenerateActionOptions: """Render the prompt template without executing, returning GenerateActionOptions. - Same keyword options as :meth:`__call__` (see :class:`PromptGenerateOptions`). + Same keyword options as ``__call__`` (see PromptGenerateOptions). """ await self._ensure_resolved() coerced = _coerce_prompt_opts(opts) diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index a42b33e779..a96348673a 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -26,7 +26,7 @@ from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import GenkitError from genkit._core._registry import Registry -from genkit._core._typing import Part, ToolDefinition, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart +from genkit._core._typing import ToolDefinition, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart class Tool: @@ -97,7 +97,7 @@ def restart( replace_input: Optional new ``tool_request.input`` for this run (previous input is stored in ``metadata.replacedInput`` when this is set). interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). - resumed_metadata: Passed to the tool as :attr:`ToolRunContext.resumed_metadata`. + resumed_metadata: Passed to the tool as ``ToolRunContext.resumed_metadata``. Returns: A ``ToolRequestPart`` for ``resume_restart`` / message history. @@ -203,12 +203,12 @@ class Interrupt(Exception): # noqa: N818 - public Genkit name; not renamed *Err """Exception for interrupting tool execution with user-facing API. 
Raise ``Interrupt(data)`` from a tool or from tool middleware (e.g. ``wrap_tool``). - Exceptions from ``tool.run`` are wrapped in :class:`~genkit._core._error.GenkitError` + Exceptions from ``tool.run`` are wrapped in GenkitError with ``cause=Interrupt``; generation attaches interrupt metadata to the pending tool request. - To resume, use :func:`respond_to_interrupt` or ``tool.restart(...)`` on the - registered :class:`Tool`. + To resume, use ``respond_to_interrupt`` or ``tool.restart(...)`` on the + registered Tool. """ def __init__(self, data: dict[str, Any] | None = None) -> None: @@ -258,11 +258,11 @@ def respond_to_interrupt( return _tool_response_part(interrupt, response, metadata) -async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolRequestPart) -> Part: +async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolRequestPart) -> ToolResponsePart: """Run a tool for ``resume_restart``: applies ``resumed`` / ``replacedInput`` from metadata. - Sets the same context variables as the tool wrapper so :class:`ToolRunContext` reflects - a resumed run. Nested interrupts during restart are not supported and raise :class:`GenkitError`. + Sets the same context variables as the tool wrapper so ToolRunContext reflects + a resumed run. Nested interrupts during restart are not supported and raise GenkitError. 
""" meta = restart_trp.metadata or {} raw_resumed = meta.get('resumed') @@ -295,13 +295,11 @@ async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolR _tool_resumed_metadata.reset(token_meta) _tool_original_input.reset(token_input) - return Part( - root=ToolResponsePart( - tool_response=ToolResponse( - name=restart_trp.tool_request.name, - ref=restart_trp.tool_request.ref, - output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, - ) + return ToolResponsePart( + tool_response=ToolResponse( + name=restart_trp.tool_request.name, + ref=restart_trp.tool_request.ref, + output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, ) ) @@ -407,10 +405,10 @@ def define_interrupt( ) -> Tool: """Register a tool that always interrupts execution. - An interrupt tool is a special tool that always raises :class:`Interrupt` with + An interrupt tool is a special tool that always raises ``Interrupt`` with optional metadata. This is useful for explicit human-in-the-loop checkpoints. - For tools that sometimes run logic and sometimes interrupt, use :func:`define_tool` - and raise :class:`Interrupt` from the handler (or use ``ToolRunContext``). + For tools that sometimes run logic and sometimes interrupt, use ``define_tool`` + and raise ``Interrupt`` from the handler (or use ``ToolRunContext``). Args: registry: The registry to register the interrupt tool in @@ -421,7 +419,7 @@ def define_interrupt( interrupt handler is typed as ``Any``; pass this so the model sees a concrete shape. Returns: - The registered tool callable (same shape as :func:`define_tool`). + The registered tool callable (same shape as ``define_tool``). 
Example: def get_meta(input: dict) -> dict: diff --git a/py/packages/genkit/tests/genkit/ai/_tools_test.py b/py/packages/genkit/tests/genkit/ai/_tools_test.py index 4d02e742dc..bed5c05852 100644 --- a/py/packages/genkit/tests/genkit/ai/_tools_test.py +++ b/py/packages/genkit/tests/genkit/ai/_tools_test.py @@ -188,8 +188,8 @@ async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 ) with pytest.raises(GenkitError) as ei: await run_tool_after_restart(action, restart_trp) - msg = ei.value.original_message.lower() - assert 'restart' in msg or 'nested' in msg or 'not supported' in msg + assert ei.value.status == 'FAILED_PRECONDITION' + assert 'interrupted again' in ei.value.original_message.lower() # --------------------------------------------------------------------------- @@ -263,5 +263,37 @@ async def t_ref(inp: dict) -> str: # noqa: ARG001 metadata={'resumed': True}, ) part = await run_tool_after_restart(action, restart_trp) - assert part.root.tool_response.ref == 'wire-ref-99' + assert part.tool_response.ref == 'wire-ref-99' + +@pytest.mark.asyncio +async def test_run_tool_after_restart_response_preserves_ref_and_uses_new_input() -> None: + """``run_tool_after_restart`` returns a ToolResponsePart whose ref matches the restart TRP, + and the tool receives the replaced input (not the original interrupted input). + """ + ai = Genkit() + received_inputs: list[dict] = [] + + @ai.tool(name='transfer') + async def transfer(inp: dict) -> str: + received_inputs.append(dict(inp)) + if not inp.get('confirmed'): + raise Interrupt({'reason': 'needs_approval'}) + return f"transferred {inp.get('amount')}" + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='transfer') + assert action is not None + + # Simulate a restart TRP: original input had confirmed=False, new input has confirmed=True. 
+ restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='transfer', ref='ref-42', input={'amount': 100, 'confirmed': True}), + metadata={'resumed': True, 'replacedInput': {'amount': 100, 'confirmed': False}}, + ) + result = await run_tool_after_restart(action, restart_trp) + + # Ref is preserved from the restart TRP. + assert result.tool_response.ref == 'ref-42' + assert result.tool_response.name == 'transfer' + # Tool received the new (replaced) input, not the original. + assert received_inputs == [{'amount': 100, 'confirmed': True}] + assert result.tool_response.output == 'transferred 100' diff --git a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py index e05b273328..78c029c27a 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py @@ -10,12 +10,12 @@ import pytest from genkit import Genkit, Message, ModelResponse -from genkit._ai._generate import _resolve_resumed_tool_request, generate_action +from genkit._ai._generate import generate_action from genkit._ai._testing import define_programmable_model from genkit._ai._tools import Interrupt, ToolRunContext, respond_to_interrupt from genkit._core._error import GenkitError from genkit._core._model import GenerateActionOptions -from genkit._core._typing import FinishReason, Part, Resume +from genkit._core._typing import FinishReason, Resume def _wire(messages: list[Message]) -> list[dict[str, Any]]: @@ -253,9 +253,11 @@ async def intr(_: dict) -> str: # noqa: ARG001 @pytest.mark.asyncio async def test_tool_either_interrupts_or_returns() -> None: - """Same ``bank_transfer`` tool: ``preapproved=False`` interrupts (short history), - ``preapproved=True`` runs through to a normal tool response and a final model turn. Two - back-to-back runs with different queued model responses, each wire asserted in full. 
+ """Same tool, two independent generate calls with different inputs. + + First call: ``preapproved=False`` → tool raises Interrupt → finish is INTERRUPTED, no tool row. + Second call: ``preapproved=True`` → tool returns 42 → finish is STOP, full tool+model rows present. + Both results are wire-asserted in full. """ ai = Genkit() pm, _ = define_programmable_model(ai) @@ -413,9 +415,9 @@ async def pay(inp: dict) -> str: message=Message.model_validate({'role': 'model', 'content': [{'text': 'final'}]}), ) ) + # ^ Queued for the second generate call (after restart re-runs the tool). first = await generate_action( - ai.registry, _gen_opts(ai, tools=['pay'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), ) assert first.finish_reason == FinishReason.INTERRUPTED @@ -584,6 +586,7 @@ async def b_tool(inp: dict) -> str: 'toolResponse': {'ref': 'ra', 'name': 'a', 'output': {'done': True}}, 'metadata': {'interruptResponse': True}, }, + # Restart path: tool re-runs and returns its own output; no interruptResponse metadata. {'toolResponse': {'ref': 'rb', 'name': 'b', 'output': 'b-done'}}, ], 'metadata': {'resumed': True}, @@ -596,28 +599,112 @@ async def b_tool(inp: dict) -> str: @pytest.mark.asyncio -async def test_pending_output_trp_yields_tool_response_with_source_pending() -> None: - """If the TRP only has ``pendingOutput``, ``_resolve_resumed_tool_request`` should synthesize a - tool response with that output and ``metadata.source`` set to ``pending``—checked on the part - directly, not a full message list. +async def test_mixed_one_interrupts_one_succeeds_pending_output_in_wire() -> None: + """Two tools in one turn: ``a`` interrupts, ``b`` succeeds. + + Turn 1: both run in parallel. ``b``'s output is stashed as ``pendingOutput`` + on its TRP in the model message (no tool message yet). finish=INTERRUPTED. + + Turn 2: resume with ``respond=[...]`` for ``a`` only — no action needed for + ``b``. 
The framework reconstructs ``b``'s tool response from the stashed + output and marks it ``source: pending`` in the wire. """ ai = Genkit() - part = Part.model_validate({ - 'toolRequest': {'name': 't', 'ref': 'r', 'input': {}}, - 'metadata': {'pendingOutput': 123}, - }) - _, resp = await _resolve_resumed_tool_request( + pm, _ = define_programmable_model(ai) + + @ai.tool(name='a') + async def a_tool(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'reason': 'needs_approval'}) + + @ai.tool(name='b') + async def b_tool(_: dict) -> int: # noqa: ARG001 + return 42 + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}}, + {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, + ], + }), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'done'}]}), + ) + ) + + first = await generate_action( + ai.registry, + _gen_opts(ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + # b's output is stashed in pendingOutput on its TRP; no tool message yet. 
+ assert _wire(first.messages) == [ + {'role': 'user', 'content': [{'text': 'hi'}]}, + { + 'role': 'model', + 'content': [ + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'interrupt': {'reason': 'needs_approval'}}, + }, + { + 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, + 'metadata': {'pendingOutput': 42}, + }, + ], + }, + ] + + ia = first.interrupts[0] + second = await generate_action( ai.registry, - GenerateActionOptions( - model='programmableModel', - messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'x'}]})], + _gen_opts( + ai, + tools=['a', 'b'], + messages=list(first.messages), + resume=Resume(respond=[respond_to_interrupt({'approved': True}, interrupt=ia)]), ), - part, ) - assert resp.model_dump(mode='json', exclude_none=True, by_alias=True) == { - 'toolResponse': {'ref': 'r', 'name': 't', 'output': 123}, - 'metadata': {'source': 'pending'}, - } + + assert second.finish_reason == FinishReason.STOP + assert _wire(second.messages) == [ + {'role': 'user', 'content': [{'text': 'hi'}]}, + { + 'role': 'model', + 'content': [ + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'reason': 'needs_approval'}}, + }, + { + 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + { + 'toolResponse': {'ref': 'ra', 'name': 'a', 'output': {'approved': True}}, + 'metadata': {'interruptResponse': True}, + }, + { + # b ran on turn 1; output reconstructed from pendingOutput stash. 
+ 'toolResponse': {'ref': 'rb', 'name': 'b', 'output': 42}, + 'metadata': {'source': 'pending'}, + }, + ], + 'metadata': {'resumed': True}, + }, + {'role': 'model', 'content': [{'text': 'done'}]}, + ] @pytest.mark.asyncio @@ -648,6 +735,7 @@ async def test_resume_without_matching_replies_raises() -> None: resume=Resume(), ), ) + assert ei.value.status == 'INVALID_ARGUMENT' assert 'replies' in ei.value.original_message or 'restarts' in ei.value.original_message @@ -668,4 +756,5 @@ async def test_resume_requires_last_message_model_with_tool_requests() -> None: resume=Resume(), ), ) + assert ei.value.status == 'FAILED_PRECONDITION' assert 'model' in ei.value.original_message.lower() diff --git a/py/packages/genkit/tests/genkit/ai/prompt_test.py b/py/packages/genkit/tests/genkit/ai/prompt_test.py index 8a62fd47da..1fbb7c3226 100644 --- a/py/packages/genkit/tests/genkit/ai/prompt_test.py +++ b/py/packages/genkit/tests/genkit/ai/prompt_test.py @@ -19,7 +19,7 @@ import tempfile from pathlib import Path -from typing import Any, cast +from typing import Any from unittest.mock import ANY, MagicMock, patch import pytest @@ -605,7 +605,7 @@ async def test_executable_prompt_opts_removed() -> None: my_prompt = ai.define_prompt(prompt='hi', output_format='text') with pytest.raises(TypeError, match='opts'): - await my_prompt(**cast(Any, {'opts': {'model': 'echoModel'}})) + await my_prompt(**{'opts': {'model': 'echoModel'}}) @pytest.mark.asyncio diff --git a/py/samples/tool-interrupts/src/approval_example.py b/py/samples/tool-interrupts/src/approval_example.py index 4d10022c07..220ed58264 100644 --- a/py/samples/tool-interrupts/src/approval_example.py +++ b/py/samples/tool-interrupts/src/approval_example.py @@ -18,7 +18,7 @@ The model calls ``request_transfer``; the CLI asks **approve (y)** or **decline (n)**. **Approve** → ``tool.restart(...)`` / ``resume_restart`` so the tool **runs again** -with :class:`~genkit.ToolRunContext.is_resumed`. 
**Decline** → ``respond_to_interrupt`` / +with ``ToolRunContext.is_resumed``. **Decline** → ``respond_to_interrupt`` / ``resume_respond`` (no second tool run). Opening prompt, then one **canned user message** (no typing) so the model calls diff --git a/py/samples/tool-interrupts/src/respond_example.py b/py/samples/tool-interrupts/src/respond_example.py index 5a3a9a979d..f709f4bff9 100755 --- a/py/samples/tool-interrupts/src/respond_example.py +++ b/py/samples/tool-interrupts/src/respond_example.py @@ -16,7 +16,7 @@ """Tool interrupts — trivia via ``present_questions`` and ``respond_to_interrupt``. -``present_questions`` raises :class:`~genkit.Interrupt` with the question payload → the user +``present_questions`` raises ``Interrupt`` with the question payload → the user picks an answer → ``respond_to_interrupt(pick, interrupt=…, metadata=…)`` → second ``generate`` with ``resume_respond``. @@ -145,7 +145,6 @@ def show(label: str, r: ModelResponse) -> None: interrupt=interrupt, metadata={'source': 'cli', 'path': 'respond'}, ) - assert isinstance(interrupt_response, ToolResponsePart) response = await ai.generate( messages=messages, resume_respond=[interrupt_response], From 1ea8a7586ab043937bcc5131f02c37665b8a8ac0 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 31 Mar 2026 00:36:14 -0700 Subject: [PATCH 12/15] fix --- .../genkit/src/genkit/_ai/_generate.py | 20 ++++++++++--------- py/packages/genkit/src/genkit/_ai/_tools.py | 17 +--------------- .../genkit/tests/genkit/ai/_tools_test.py | 8 +++++--- .../ai/generate_interrupt_resume_test.py | 12 +++++------ .../plugins/google_genai/models/gemini.py | 11 +--------- 5 files changed, 24 insertions(+), 44 deletions(-) diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index e3af426c5e..149ffc9c91 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -16,11 +16,8 @@ """Generate action.""" -<<<<<<< HEAD 
import asyncio -======= import contextlib ->>>>>>> main import copy import inspect import re @@ -848,19 +845,24 @@ async def _resolve_resumed_tool_request( tool_req_root = tool_request_part.root if tool_req_root.metadata and 'pendingOutput' in tool_req_root.metadata: - metadata = dict(tool_req_root.metadata) - pending_output = metadata['pendingOutput'] - del metadata['pendingOutput'] - metadata['source'] = 'pending' + # resolveResumedToolRequest: strip pendingOutput from the model TRP; reconstruct + # output on the tool message with metadata { ...rest, source: 'pending' }. + trp_metadata = dict(tool_req_root.metadata) + pending_output = trp_metadata.pop('pendingOutput') + revised_trp = ToolRequestPart( + tool_request=tool_req_root.tool_request, + metadata=trp_metadata if trp_metadata else None, + ) + response_metadata = {**trp_metadata, 'source': 'pending'} return ( - tool_req_root, + revised_trp, ToolResponsePart( tool_response=ToolResponse( name=tool_req_root.tool_request.name, ref=tool_req_root.tool_request.ref, output=pending_output.model_dump() if isinstance(pending_output, BaseModel) else pending_output, ), - metadata=metadata, + metadata=response_metadata, ), ) diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index a96348673a..3dc4be7732 100644 --- a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -39,10 +39,6 @@ class Tool: def __init__(self, action: Action) -> None: self._action = action - # ------------------------------------------------------------------ - # Properties that delegate to the underlying Action - # ------------------------------------------------------------------ - @property def name(self) -> str: """Tool name (registry key).""" @@ -64,7 +60,7 @@ def output_schema(self) -> dict[str, object] | None: return self._action.output_schema def definition(self) -> ToolDefinition: - """Return the wire-format :class:`ToolDefinition` for this 
tool.""" + """Return the wire-format ToolDefinition for this tool.""" return ToolDefinition( name=self.name, description=self.description, @@ -72,18 +68,10 @@ def definition(self) -> ToolDefinition: output_schema=self.output_schema, ) - # ------------------------------------------------------------------ - # Execution - # ------------------------------------------------------------------ - async def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 """Run the tool and return the unwrapped response value.""" return (await self._action.run(*args, **kwargs)).response - # ------------------------------------------------------------------ - # Interrupt / restart helpers - # ------------------------------------------------------------------ - def restart( self, replace_input: Any | None = None, # noqa: ANN401 @@ -120,9 +108,6 @@ def restart( new_meta['replacedInput'] = tool_req.input new_input = replace_input - if 'interrupt' in new_meta: - del new_meta['interrupt'] - return ToolRequestPart( tool_request=ToolRequest( name=tool_req.name, diff --git a/py/packages/genkit/tests/genkit/ai/_tools_test.py b/py/packages/genkit/tests/genkit/ai/_tools_test.py index bed5c05852..38f9ab0203 100644 --- a/py/packages/genkit/tests/genkit/ai/_tools_test.py +++ b/py/packages/genkit/tests/genkit/ai/_tools_test.py @@ -19,8 +19,8 @@ @pytest.mark.asyncio -async def test_restart_sets_resumed_metadata_and_strips_interrupt() -> None: - """``tool.restart`` → TRP.metadata.resumed plus ``interrupt`` removed; input unchanged when replace is None.""" +async def test_restart_sets_resumed_metadata_and_preserves_interrupt() -> None: + """``tool.restart``: copy interrupt metadata, set ``resumed``; ``interrupt`` stays on the restart TRP.""" ai = Genkit() @ai.tool(name='pay') @@ -35,7 +35,7 @@ async def pay(inp: dict) -> str: # noqa: ARG001 assert isinstance(out, ToolRequestPart) assert out.metadata is not None assert out.metadata.get('resumed') == {'k': 'v'} - assert 'interrupt' not in 
out.metadata + assert out.metadata.get('interrupt') == {'reason': 'hold'} assert out.tool_request.input == {'amount': 10} @@ -58,6 +58,7 @@ async def pay(inp: dict) -> str: # noqa: ARG001 assert out.metadata.get('replacedInput') == {'amount': 10} assert out.tool_request.input == {'amount': 99} assert out.metadata.get('resumed') == {'by': 'u'} + assert out.metadata.get('interrupt') is True @pytest.mark.asyncio @@ -77,6 +78,7 @@ async def pay(inp: dict) -> str: # noqa: ARG001 assert isinstance(out, ToolRequestPart) assert out.metadata is not None assert out.metadata.get('resumed') is True + assert out.metadata.get('interrupt') is True @pytest.mark.asyncio diff --git a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py index 78c029c27a..ff7a553ad5 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py @@ -418,6 +418,7 @@ async def pay(inp: dict) -> str: # ^ Queued for the second generate call (after restart re-runs the tool). first = await generate_action( + ai.registry, _gen_opts(ai, tools=['pay'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), ) assert first.finish_reason == FinishReason.INTERRUPTED @@ -607,7 +608,8 @@ async def test_mixed_one_interrupts_one_succeeds_pending_output_in_wire() -> Non Turn 2: resume with ``respond=[...]`` for ``a`` only — no action needed for ``b``. The framework reconstructs ``b``'s tool response from the stashed - output and marks it ``source: pending`` in the wire. + output, strips ``pendingOutput`` from ``b``'s model TRP, and + marks the tool response ``source: pending`` on the wire. 
""" ai = Genkit() pm, _ = define_programmable_model(ai) @@ -683,9 +685,7 @@ async def b_tool(_: dict) -> int: # noqa: ARG001 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, 'metadata': {'resolvedInterrupt': {'reason': 'needs_approval'}}, }, - { - 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, - }, + {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, ], }, { @@ -736,7 +736,7 @@ async def test_resume_without_matching_replies_raises() -> None: ), ) assert ei.value.status == 'INVALID_ARGUMENT' - assert 'replies' in ei.value.original_message or 'restarts' in ei.value.original_message + assert 'unresolved tool request' in ei.value.original_message.lower() @pytest.mark.asyncio @@ -757,4 +757,4 @@ async def test_resume_requires_last_message_model_with_tool_requests() -> None: ), ) assert ei.value.status == 'FAILED_PRECONDITION' - assert 'model' in ei.value.original_message.lower() + assert "cannot 'resume'" in ei.value.original_message.lower() diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 4d7424784a..0f97947aca 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -1532,15 +1532,6 @@ async def _streaming_generate( elif e.code == 429: status = 'RESOURCE_EXHAUSTED' -<<<<<<< HEAD - raise GenkitError( - status=status, - message=e.message or 'Unknown error', - cause=e, - ) from e - accumulated_content = [] - finish_reason: FinishReason | None = None -======= raise GenkitError( status=status, message=e.message or 'Unknown error', @@ -1548,7 +1539,7 @@ async def _streaming_generate( ) from e accumulated_content: list[Part] = [] ->>>>>>> main + finish_reason: FinishReason | None = None async for response_chunk in generator: content = await self._contents_from_response(response_chunk) if content: # Only process if we have 
content From 105189f069e933834fc2c039ba3b062366add465 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 31 Mar 2026 09:22:07 -0700 Subject: [PATCH 13/15] polish and revert fastapi change --- py/packages/genkit/tests/genkit/ai/_tools_test.py | 15 ++++++++++----- py/plugins/fastapi/pyproject.toml | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/py/packages/genkit/tests/genkit/ai/_tools_test.py b/py/packages/genkit/tests/genkit/ai/_tools_test.py index 38f9ab0203..715cee8dab 100644 --- a/py/packages/genkit/tests/genkit/ai/_tools_test.py +++ b/py/packages/genkit/tests/genkit/ai/_tools_test.py @@ -270,15 +270,18 @@ async def t_ref(inp: dict) -> str: # noqa: ARG001 @pytest.mark.asyncio async def test_run_tool_after_restart_response_preserves_ref_and_uses_new_input() -> None: - """``run_tool_after_restart`` returns a ToolResponsePart whose ref matches the restart TRP, - and the tool receives the replaced input (not the original interrupted input). + """``run_tool_after_restart`` returns a ToolResponsePart whose ref matches the restart TRP; + ``tool_request.input`` is what ``tool.run`` receives, and ``metadata.replacedInput`` is + ``ToolRunContext.original_input`` (prior interrupted input). """ ai = Genkit() received_inputs: list[dict] = [] + original_inputs: list[object | None] = [] @ai.tool(name='transfer') - async def transfer(inp: dict) -> str: + async def transfer(inp: dict, ctx: ToolRunContext) -> str: received_inputs.append(dict(inp)) + original_inputs.append(ctx.original_input) if not inp.get('confirmed'): raise Interrupt({'reason': 'needs_approval'}) return f"transferred {inp.get('amount')}" @@ -286,16 +289,18 @@ async def transfer(inp: dict) -> str: action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='transfer') assert action is not None + prior = {'amount': 100, 'confirmed': False} # Simulate a restart TRP: original input had confirmed=False, new input has confirmed=True. 
restart_trp = ToolRequestPart( tool_request=ToolRequest(name='transfer', ref='ref-42', input={'amount': 100, 'confirmed': True}), - metadata={'resumed': True, 'replacedInput': {'amount': 100, 'confirmed': False}}, + metadata={'resumed': True, 'replacedInput': prior}, ) result = await run_tool_after_restart(action, restart_trp) # Ref is preserved from the restart TRP. assert result.tool_response.ref == 'ref-42' assert result.tool_response.name == 'transfer' - # Tool received the new (replaced) input, not the original. + # Primary arg is current tool_request.input; replacedInput is surfaced as original_input. assert received_inputs == [{'amount': 100, 'confirmed': True}] + assert original_inputs == [prior] assert result.tool_response.output == 'transferred 100' diff --git a/py/plugins/fastapi/pyproject.toml b/py/plugins/fastapi/pyproject.toml index 9438a86c41..d55139a55f 100644 --- a/py/plugins/fastapi/pyproject.toml +++ b/py/plugins/fastapi/pyproject.toml @@ -66,4 +66,4 @@ build-backend = "hatchling.build" requires = ["hatchling"] [tool.hatch.build.targets.wheel] -only-include = ["src/genkit/plugins/fastapi"] +packages = ["src/genkit", "src/genkit/plugins"] From 0fd00fa4d058dbcea567edf8e8b6b55ad0da7c5d Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 31 Mar 2026 10:01:39 -0700 Subject: [PATCH 14/15] fix lint --- py/packages/genkit/src/genkit/_ai/_generate.py | 4 +++- py/packages/genkit/tests/genkit/ai/_tools_test.py | 2 +- py/packages/genkit/tests/genkit/ai/generate_helpers_test.py | 2 +- .../genkit/tests/genkit/ai/generate_interrupt_resume_test.py | 4 +++- py/packages/genkit/tests/genkit/ai/prompt_test.py | 3 ++- py/samples/tool-interrupts/src/respond_example.py | 1 - 6 files changed, 10 insertions(+), 6 deletions(-) diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 149ffc9c91..df29295e47 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ 
b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -933,7 +933,9 @@ def _find_corresponding_restart( return None -def _find_corresponding_tool_response(responses: list[ToolResponsePart], request: ToolRequestPart) -> ToolResponsePart | None: +def _find_corresponding_tool_response( + responses: list[ToolResponsePart], request: ToolRequestPart +) -> ToolResponsePart | None: """Find a response matching the request by name and ref.""" for p in responses: if p.tool_response.name == request.tool_request.name and p.tool_response.ref == request.tool_request.ref: diff --git a/py/packages/genkit/tests/genkit/ai/_tools_test.py b/py/packages/genkit/tests/genkit/ai/_tools_test.py index 715cee8dab..e5228931ba 100644 --- a/py/packages/genkit/tests/genkit/ai/_tools_test.py +++ b/py/packages/genkit/tests/genkit/ai/_tools_test.py @@ -284,7 +284,7 @@ async def transfer(inp: dict, ctx: ToolRunContext) -> str: original_inputs.append(ctx.original_input) if not inp.get('confirmed'): raise Interrupt({'reason': 'needs_approval'}) - return f"transferred {inp.get('amount')}" + return f'transferred {inp.get("amount")}' action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='transfer') assert action is not None diff --git a/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py index dca72340b2..60c9ada561 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py @@ -54,7 +54,7 @@ def test_find_corresponding_tool_response_matches_name_and_ref() -> None: got = _find_corresponding_tool_response([trp], pending) assert got is not None - assert got.root == trp + assert got == trp assert _find_corresponding_tool_response([other], pending) is None assert _find_corresponding_tool_response([], pending) is None diff --git a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py 
b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py index ff7a553ad5..26bcfe5af8 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py @@ -643,7 +643,9 @@ async def b_tool(_: dict) -> int: # noqa: ARG001 first = await generate_action( ai.registry, - _gen_opts(ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + _gen_opts( + ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})] + ), ) assert first.finish_reason == FinishReason.INTERRUPTED # b's output is stashed in pendingOutput on its TRP; no tool message yet. diff --git a/py/packages/genkit/tests/genkit/ai/prompt_test.py b/py/packages/genkit/tests/genkit/ai/prompt_test.py index 1fbb7c3226..4ea4194580 100644 --- a/py/packages/genkit/tests/genkit/ai/prompt_test.py +++ b/py/packages/genkit/tests/genkit/ai/prompt_test.py @@ -605,7 +605,8 @@ async def test_executable_prompt_opts_removed() -> None: my_prompt = ai.define_prompt(prompt='hi', output_format='text') with pytest.raises(TypeError, match='opts'): - await my_prompt(**{'opts': {'model': 'echoModel'}}) + # Invalid kwarg on purpose; static checkers need a hint (runtime rejects with TypeError). 
+ await my_prompt(opts={'model': 'echoModel'}) # pyrefly: ignore[unexpected-keyword] # pyright: ignore @pytest.mark.asyncio diff --git a/py/samples/tool-interrupts/src/respond_example.py b/py/samples/tool-interrupts/src/respond_example.py index f709f4bff9..741929fd99 100755 --- a/py/samples/tool-interrupts/src/respond_example.py +++ b/py/samples/tool-interrupts/src/respond_example.py @@ -32,7 +32,6 @@ from genkit import ( Genkit, Interrupt, - ToolResponsePart, respond_to_interrupt, ) from genkit.model import ModelResponse From e2f48e4abfc00c55943ad172750b22e9e87e9ad3 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 31 Mar 2026 10:18:04 -0700 Subject: [PATCH 15/15] fix test --- .../genkit/tests/genkit/veneer/veneer_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index 307fb36c0d..f3e08af8e9 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -556,11 +556,11 @@ async def test_interrupt(input: ToolInput) -> None: assert interrupted_response.messages == [ Message( - role='user', + role=Role.USER, content=[Part(root=TextPart(text='hi'))], ), Message( - role='model', + role=Role.MODEL, content=[ Part(root=TextPart(text='call these tools')), Part( @@ -592,11 +592,11 @@ async def test_interrupt(input: ToolInput) -> None: assert response.messages == [ Message( - role='user', + role=Role.USER, content=[Part(root=TextPart(text='hi'))], ), Message( - role='model', + role=Role.MODEL, content=[ Part(root=TextPart(text='call these tools')), Part( @@ -608,14 +608,14 @@ async def test_interrupt(input: ToolInput) -> None: Part( root=ToolRequestPart( tool_request=ToolRequest(ref='234', name='test_tool', input={'value': 5}), - metadata={'pendingOutput': 12}, + metadata=None, ) ), ], metadata=None, ), Message( - role='tool', + role=Role.TOOL, content=[ 
Part( root=ToolResponsePart( @@ -633,7 +633,7 @@ async def test_interrupt(input: ToolInput) -> None: metadata={'resumed': True}, ), Message( - role='model', + role=Role.MODEL, content=[Part(root=TextPart(text='tool called'))], metadata=None, ),