diff --git a/py/README.md b/py/README.md index a54a194bb0..0ddb09f373 100644 --- a/py/README.md +++ b/py/README.md @@ -424,7 +424,7 @@ The Python samples are intentionally short and beginner-friendly. Use this table | `middleware` | Middleware | Observe or modify model requests with `use=[...]` | | `output-formats` | Structured Output | Text, enum, JSON object, array, and JSONL outputs | | `prompts` | Prompt, Dotprompt | Use `.prompt` files, helpers, variants, and streaming | -| `tool-interrupts` | Tool | Pause tool execution for human approval | +| `tool-interrupts` | Tool | Trivia (`respond_example.py`) and bank approval (`approval_example.py`) for interrupt / resume | | `tracing` | Tracing | Watch spans appear in the Dev UI as they start | | `vertexai-imagen` | Vertex AI, Image Generation | Generate an image with Vertex AI Imagen | diff --git a/py/docs/index.md b/py/docs/index.md index 3023418f48..f7c2618977 100644 --- a/py/docs/index.md +++ b/py/docs/index.md @@ -35,7 +35,7 @@ ::: genkit.PublicError -::: genkit.ToolInterruptError +::: genkit.Interrupt ::: genkit.Message diff --git a/py/docs/types.md b/py/docs/types.md index 823b4c2259..239d1faf94 100644 --- a/py/docs/types.md +++ b/py/docs/types.md @@ -32,7 +32,7 @@ Types exported from genkit, genkit.model, genkit.embedder, genkit.plugin_api, an ::: genkit.PublicError -::: genkit.ToolInterruptError +::: genkit.Interrupt ::: genkit.Message diff --git a/py/packages/genkit/src/genkit/__init__.py b/py/packages/genkit/src/genkit/__init__.py index 2af08aa319..0a4c5ab28d 100644 --- a/py/packages/genkit/src/genkit/__init__.py +++ b/py/packages/genkit/src/genkit/__init__.py @@ -21,9 +21,13 @@ ExecutablePrompt, ModelStreamResponse, PromptGenerateOptions, - ResumeOptions, ) -from genkit._ai._tools import ToolInterruptError, ToolRunContext, tool_response +from genkit._ai._tools import ( + Interrupt, + Tool, + ToolRunContext, + respond_to_interrupt, +) from genkit._core._action import Action, StreamResponse from 
genkit._core._error import GenkitError, PublicError from genkit._core._model import Document @@ -93,7 +97,9 @@ # Errors 'GenkitError', 'PublicError', - 'ToolInterruptError', + 'Interrupt', + 'Tool', + 'respond_to_interrupt', # Content types 'Constrained', 'CustomPart', @@ -126,9 +132,7 @@ 'ActionRunContext', 'ExecutablePrompt', 'PromptGenerateOptions', - 'ResumeOptions', 'ToolRunContext', - 'tool_response', 'ModelRequest', 'ModelResponse', 'ModelResponseChunk', diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index 25af65fb11..b0afc530e3 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -25,9 +25,9 @@ import socket import threading import uuid -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Awaitable, Callable, Coroutine, Sequence from pathlib import Path -from typing import Any, ParamSpec, TypeVar, cast, overload +from typing import Any, TypeVar, cast, overload import anyio import uvicorn @@ -71,7 +71,7 @@ ResourceOptions, define_resource, ) -from genkit._ai._tools import define_tool +from genkit._ai._tools import Tool, define_interrupt, define_tool from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._background import ( BackgroundAction, @@ -107,6 +107,8 @@ Part, SpanMetadata, ToolChoice, + ToolRequestPart, + ToolResponsePart, ) from ._decorators import _FlowDecorator, _FlowDecoratorWithChunk @@ -118,7 +120,7 @@ InputT = TypeVar('InputT') OutputT = TypeVar('OutputT') ChunkT = TypeVar('ChunkT') -P = ParamSpec('P') + R = TypeVar('R') T = TypeVar('T') @@ -260,16 +262,45 @@ def define_dynamic_action_provider( metadata=metadata, ) - def tool( - self, name: str | None = None, description: str | None = None - ) -> Callable[[Callable[P, T]], Callable[P, T]]: + def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable[..., Any]], Tool]: """Decorator to 
register a function as a tool.""" - def wrapper(func: Callable[P, T]) -> Callable[P, T]: + def wrapper(func: Callable[..., Any]) -> Tool: return define_tool(self.registry, func, name, description) return wrapper + def define_interrupt( + self, + name: str, + *, + input_schema: type[BaseModel] | dict[str, object] | None = None, + description: str | None = None, + ) -> Tool: + """Register an interrupt tool that always pauses for user input. + + Args: + name: Tool name + input_schema: Optional input schema (Pydantic model or JSON schema dict) + description: Tool description + + Returns: + The registered interrupt tool + + Example: + ask_user = ai.define_interrupt( + name='ask_user', + input_schema=Question, + description='Ask the user a question', + ) + """ + return define_interrupt( + self.registry, + name, + description=description, + input_schema=input_schema, + ) + def define_evaluator( self, *, @@ -393,7 +424,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -421,7 +452,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -449,7 +480,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -477,7 +508,7 @@ def define_prompt( 
max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -503,7 +534,7 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, object] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -743,10 +774,12 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, - tool_responses: list[Part] | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -768,10 +801,12 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, - tool_responses: list[Part] | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, 
context: dict[str, object] | None = None, @@ -791,10 +826,12 @@ async def generate( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, - tool_responses: list[Part] | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -806,7 +843,12 @@ async def generate( use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[Any]: - """Generate text or structured data using a language model.""" + """Generate text or structured data using a language model. + + ``tools`` is typed as ``Sequence`` rather than ``list`` because ``Sequence`` + is covariant: ``list[Tool]`` or ``list[str]`` are both assignable to + ``Sequence[str | Tool]``, but not to ``list[str | Tool]``. 
+ """ return await generate_action( self.registry, await to_generate_action_options( @@ -819,7 +861,9 @@ async def generate( tools=tools, return_tool_requests=return_tool_requests, tool_choice=tool_choice, - tool_responses=tool_responses, + resume_respond=resume_respond, + resume_restart=resume_restart, + resume_metadata=resume_metadata, config=config, max_turns=max_turns, output_format=output_format, @@ -843,9 +887,12 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -868,9 +915,12 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -891,9 +941,12 @@ def generate_stream( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: 
ToolChoice | None = None, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, config: dict[str, object] | ModelConfig | None = None, max_turns: int | None = None, context: dict[str, object] | None = None, @@ -922,6 +975,9 @@ async def _run_generate() -> ModelResponse[Any]: tools=tools, return_tool_requests=return_tool_requests, tool_choice=tool_choice, + resume_respond=resume_respond, + resume_restart=resume_restart, + resume_metadata=resume_metadata, config=config, max_turns=max_turns, output_format=output_format, @@ -1055,26 +1111,6 @@ def current_context() -> dict[str, Any] | None: """Get the current execution context, or None if not in an action.""" return ActionRunContext._current_context() # pyright: ignore[reportPrivateUsage] - def dynamic_tool( - self, - *, - name: str, - fn: Callable[..., object], - description: str | None = None, - metadata: dict[str, object] | None = None, - ) -> Action: - """Create an unregistered tool action for passing directly to generate().""" - tool_meta: dict[str, object] = metadata.copy() if metadata else {} - tool_meta['type'] = 'tool' - tool_meta['dynamic'] = True - return Action( - kind=ActionKind.TOOL, - name=name, - fn=fn, # type: ignore[arg-type] # dynamic tools may be sync - description=description, - metadata=tool_meta, - ) - async def flush_tracing(self) -> None: """Flush all pending trace spans to exporters.""" provider = trace_api.get_tracer_provider() @@ -1132,7 +1168,7 @@ async def generate_operation( prompt: str | list[Part] | None = None, system: str | list[Part] | None = None, messages: list[Message] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, return_tool_requests: bool | None = None, tool_choice: ToolChoice | None = None, config: dict[str, object] | ModelConfig | None = None, diff --git 
a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 175e92455f..df29295e47 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -16,11 +16,12 @@ """Generate action.""" +import asyncio import contextlib import copy import inspect import re -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Any, cast from pydantic import BaseModel @@ -36,7 +37,7 @@ ModelResponseChunk, ) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources -from genkit._ai._tools import ToolInterruptError +from genkit._ai._tools import Interrupt, Tool, run_tool_after_restart, unwrap_wrapped_scalar_tool_input_if_needed from genkit._core._action import Action, ActionKind, ActionRunContext from genkit._core._error import GenkitError from genkit._core._logger import get_logger @@ -60,6 +61,25 @@ logger = get_logger(__name__) +def tools_to_action_names( + tools: Sequence[str | Tool] | None, +) -> list[str] | None: + """Normalize tool arguments to registry names for GenerateActionOptions. + + Each item may be a tool name (``str``) or a Tool returned by + Genkit.tool(). + """ + if tools is None: + return None + names: list[str] = [] + for t in tools: + if isinstance(t, str): + names.append(t) + else: + names.append(t.name) + return names + + # Matches data URIs: everything up to the first comma is the media-type + # parameters (e.g. "data:audio/L16;codec=pcm;rate=24000;base64,"). 
_DATA_URI_RE = re.compile(r'data:[^,]{0,200},(?=.{100})', re.ASCII) @@ -589,10 +609,7 @@ async def resolve_parameters( tools: list[Action[Any, Any, Any]] = [] if request.tools: for tool_name in request.tools: - tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) - if tool_action is None: - raise Exception(f'Unable to resolve tool {tool_name}') - tools.append(tool_action) + tools.append(await resolve_tool(registry, tool_name)) format_def: FormatDef | None = None if request.output and request.output.format: @@ -653,34 +670,37 @@ async def resolve_tool_requests( revised_model_message = message.model_copy(deep=True) - has_interrupts = False - response_parts: list[Part] = [] + work: list[tuple[int, Action, ToolRequestPart]] = [] i = 0 for tool_request_part in message.content: if not (isinstance(tool_request_part, Part) and isinstance(tool_request_part.root, ToolRequestPart)): # pyright: ignore[reportUnnecessaryIsInstance] i += 1 continue - # Type is now narrowed: tool_request_part.root is ToolRequestPart tool_req_root = tool_request_part.root tool_request = tool_req_root.tool_request if tool_request.name not in tool_dict: raise RuntimeError(f'failed {tool_request.name} not found') tool = tool_dict[tool_request.name] - tool_response_part, interrupt_part = await _resolve_tool_request(tool, tool_req_root) + work.append((i, tool, tool_req_root)) + i += 1 + + if not work: + return (None, Message(role=Role.TOOL, content=[]), None) + + outs = await asyncio.gather(*[_resolve_tool_request(tool, trp) for _, tool, trp in work]) + has_interrupts = False + response_parts: list[Part] = [] + for (idx, _tool, tool_req_root), (tool_response_part, interrupt_part) in zip(work, outs, strict=True): if tool_response_part: - # Extract the ToolResponsePart from the returned Part for _to_pending_response - if isinstance(tool_response_part.root, ToolResponsePart): - revised_model_message.content[i] = _to_pending_response(tool_req_root, tool_response_part.root) - 
response_parts.append(tool_response_part) + revised_model_message.content[idx] = _to_pending_response(tool_req_root, tool_response_part) + response_parts.append(Part(root=tool_response_part)) if interrupt_part: has_interrupts = True - revised_model_message.content[i] = interrupt_part - - i += 1 + revised_model_message.content[idx] = Part(root=interrupt_part) if has_interrupts: return (revised_model_message, None, None) @@ -701,48 +721,57 @@ def _to_pending_response(request: ToolRequestPart, response: ToolResponsePart) - ) -async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart) -> tuple[Part | None, Part | None]: - """Execute a tool and return (response_part, interrupt_part).""" +def _interrupt_from_tool_exc(exc: BaseException) -> Interrupt | None: + """If ``exc`` is (or wraps) an Interrupt exception, return that interrupt.""" + if isinstance(exc, Interrupt): + return exc + if isinstance(exc, GenkitError) and exc.cause is not None and isinstance(exc.cause, Interrupt): + return exc.cause + return None + + +async def _resolve_tool_request( + tool: Action, tool_request_part: ToolRequestPart +) -> tuple[ToolResponsePart | None, ToolRequestPart | None]: + """Execute a tool. + + Returns ``(ToolResponsePart, None)`` on success or ``(None, ToolRequestPart)`` when interrupted. 
+ """ try: - tool_response = (await tool.run(tool_request_part.tool_request.input)).response - # Part is a RootModel, so we pass content via 'root' parameter + tool_in = unwrap_wrapped_scalar_tool_input_if_needed( + tool_request_part.tool_request.input, + tool.input_schema, + ) + tool_response = (await tool.run(tool_in)).response return ( - Part( - root=ToolResponsePart( - tool_response=ToolResponse( - name=tool_request_part.tool_request.name, - ref=tool_request_part.tool_request.ref, - output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, - ) + ToolResponsePart( + tool_response=ToolResponse( + name=tool_request_part.tool_request.name, + ref=tool_request_part.tool_request.ref, + output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, ) ), None, ) - except GenkitError as e: - if e.cause and isinstance(e.cause, ToolInterruptError): - interrupt_error = e.cause - # Part is a RootModel, so we pass content via 'root' parameter + except Exception as e: + intr = _interrupt_from_tool_exc(e) + if intr is not None: + payload: dict[str, Any] | bool = intr.data if intr.data else True tool_meta = tool_request_part.metadata or {} if not isinstance(tool_meta, dict): tool_meta = dict(tool_meta) return ( None, - Part( - root=ToolRequestPart( - tool_request=tool_request_part.tool_request, - metadata={ - **tool_meta, - 'interrupt': (interrupt_error.metadata if interrupt_error.metadata else True), - }, - ) + ToolRequestPart( + tool_request=tool_request_part.tool_request, + metadata={**tool_meta, 'interrupt': payload}, ), ) - - raise e + raise async def resolve_tool(registry: Registry, tool_name: str) -> Action: - """Resolve a tool by name from the registry.""" + """Resolve a tool action by name from the registry.""" tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name) if tool is None: raise ValueError(f'Unable to resolve tool {tool_name}') @@ -777,9 +806,9 @@ async def 
_resolve_resume_options( i += 1 continue - resumed_request, resumed_response = _resolve_resumed_tool_request(raw_request, part) - tool_responses.append(resumed_response) - updated_content[i] = resumed_request + resumed_request, resumed_response = await _resolve_resumed_tool_request(_registry, raw_request, part) + tool_responses.append(Part(root=resumed_response)) + updated_content[i] = Part(root=resumed_request) i += 1 last_message.content = updated_content @@ -802,8 +831,10 @@ async def _resolve_resume_options( return (revised_request, None, tool_message) -def _resolve_resumed_tool_request(raw_request: GenerateActionOptions, tool_request_part: Part) -> tuple[Part, Part]: - """Resolve a single tool request from pending output or resume.respond list.""" +async def _resolve_resumed_tool_request( + registry: Registry, raw_request: GenerateActionOptions, tool_request_part: Part +) -> tuple[ToolRequestPart, ToolResponsePart]: + """Resolve a single tool request from pending output, resume.respond, or resume.restart.""" # Type narrowing: ensure we're working with a ToolRequestPart if not isinstance(tool_request_part.root, ToolRequestPart): raise GenkitError( @@ -814,22 +845,24 @@ def _resolve_resumed_tool_request(raw_request: GenerateActionOptions, tool_reque tool_req_root = tool_request_part.root if tool_req_root.metadata and 'pendingOutput' in tool_req_root.metadata: - metadata = dict(tool_req_root.metadata) - pending_output = metadata['pendingOutput'] - del metadata['pendingOutput'] - metadata['source'] = 'pending' + # resolveResumedToolRequest: strip pendingOutput from the model TRP; reconstruct + # output on the tool message with metadata { ...rest, source: 'pending' }. 
+ trp_metadata = dict(tool_req_root.metadata) + pending_output = trp_metadata.pop('pendingOutput') + revised_trp = ToolRequestPart( + tool_request=tool_req_root.tool_request, + metadata=trp_metadata if trp_metadata else None, + ) + response_metadata = {**trp_metadata, 'source': 'pending'} return ( - tool_request_part, - # Part is a RootModel, so we pass content via 'root' parameter - Part( - root=ToolResponsePart( - tool_response=ToolResponse( - name=tool_req_root.tool_request.name, - ref=tool_req_root.tool_request.ref, - output=pending_output.model_dump() if isinstance(pending_output, BaseModel) else pending_output, - ), - metadata=metadata, - ) + revised_trp, + ToolResponsePart( + tool_response=ToolResponse( + name=tool_req_root.tool_request.name, + ref=tool_req_root.tool_request.ref, + output=pending_output.model_dump() if isinstance(pending_output, BaseModel) else pending_output, + ), + metadata=response_metadata, ), ) @@ -845,20 +878,40 @@ def _resolve_resumed_tool_request(raw_request: GenerateActionOptions, tool_reque if interrupt: del metadata['interrupt'] return ( - # Part is a RootModel, so we pass content via 'root' parameter - Part( - root=ToolRequestPart( - tool_request=ToolRequest( - name=tool_req_root.tool_request.name, - ref=tool_req_root.tool_request.ref, - input=tool_req_root.tool_request.input, - ), - metadata={**metadata, 'resolvedInterrupt': interrupt}, - ) + ToolRequestPart( + tool_request=ToolRequest( + name=tool_req_root.tool_request.name, + ref=tool_req_root.tool_request.ref, + input=tool_req_root.tool_request.input, + ), + metadata={**metadata, 'resolvedInterrupt': interrupt}, ), provided_response, ) + restart_trp = _find_corresponding_restart( + raw_request.resume.restart if raw_request.resume else None, + tool_req_root, + ) + if restart_trp: + tool = await resolve_tool(registry, tool_req_root.tool_request.name) + executed = await run_tool_after_restart(tool, restart_trp) + metadata = dict(tool_req_root.metadata) if tool_req_root.metadata 
else {} + interrupt = metadata.get('interrupt') + if interrupt: + del metadata['interrupt'] + return ( + ToolRequestPart( + tool_request=ToolRequest( + name=tool_req_root.tool_request.name, + ref=tool_req_root.tool_request.ref, + input=tool_req_root.tool_request.input, + ), + metadata={**metadata, 'resolvedInterrupt': interrupt}, + ), + executed, + ) + raise GenkitError( status='INVALID_ARGUMENT', message=f"Unresolved tool request '{tool_req_root.tool_request.name}' " @@ -867,11 +920,26 @@ def _resolve_resumed_tool_request(raw_request: GenerateActionOptions, tool_reque ) -def _find_corresponding_tool_response(responses: list[ToolResponsePart], request: ToolRequestPart) -> Part | None: +def _find_corresponding_restart( + restarts: list[ToolRequestPart] | None, + request: ToolRequestPart, +) -> ToolRequestPart | None: + """Find a restart part matching the pending request by name and ref.""" + if not restarts: + return None + for trp in restarts: + if trp.tool_request.name == request.tool_request.name and trp.tool_request.ref == request.tool_request.ref: + return trp + return None + + +def _find_corresponding_tool_response( + responses: list[ToolResponsePart], request: ToolRequestPart +) -> ToolResponsePart | None: """Find a response matching the request by name and ref.""" for p in responses: if p.tool_response.name == request.tool_request.name and p.tool_response.ref == request.tool_request.ref: - return Part(root=p) + return p return None diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 82e0985ab3..a22437959a 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -20,7 +20,7 @@ import asyncio import os import weakref -from collections.abc import AsyncIterable, Awaitable, Callable +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass from pathlib import Path from typing import Any, 
ClassVar, Generic, TypedDict, TypeVar, cast @@ -32,10 +32,13 @@ PromptMetadata, ) from pydantic import BaseModel, ConfigDict +from typing_extensions import Unpack from genkit._ai._generate import ( generate_action, + resolve_tool, to_tool_definition, + tools_to_action_names, ) from genkit._ai._model import ( Message, @@ -44,6 +47,7 @@ ModelResponse, ModelResponseChunk, ) +from genkit._ai._tools import Tool from genkit._core._action import Action, ActionKind, ActionRunContext, StreamingCallback, create_action_key from genkit._core._channel import Channel from genkit._core._error import GenkitError @@ -83,12 +87,34 @@ class OutputOptions(TypedDict, total=False): constrained: bool | None -class ResumeOptions(TypedDict, total=False): - """Options for resuming generation after a tool interrupt.""" +def _normalize_resume_respond_parts( + value: ToolResponsePart | list[ToolResponsePart] | None, +) -> list[ToolResponsePart] | None: + if value is None: + return None + return list(value) if isinstance(value, list) else [value] + + +def _normalize_resume_restart_parts( + value: ToolRequestPart | list[ToolRequestPart] | None, +) -> list[ToolRequestPart] | None: + if value is None: + return None + return list(value) if isinstance(value, list) else [value] - respond: ToolResponsePart | list[ToolResponsePart] | None - restart: ToolRequestPart | list[ToolRequestPart] | None - metadata: dict[str, Any] | None + +def resume_options_to_resume( + *, + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None, + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None, + resume_metadata: dict[str, Any] | None = None, +) -> Resume | None: + """Build wire Resume from flat keyword options (``generate`` / prompts).""" + respond = _normalize_resume_respond_parts(resume_respond) + restart = _normalize_resume_restart_parts(resume_restart) + if respond is None and restart is None and resume_metadata is None: + return None + return Resume(respond=respond, 
restart=restart, metadata=resume_metadata) class PromptGenerateOptions(TypedDict, total=False): @@ -98,11 +124,13 @@ class PromptGenerateOptions(TypedDict, total=False): config: dict[str, Any] | ModelConfig | None messages: list[Message] | None docs: list[Document] | None - tools: list[str] | None + tools: Sequence[str | Tool] | None resources: list[str] | None tool_choice: ToolChoice | None output: OutputOptions | None - resume: ResumeOptions | None + resume_respond: ToolResponsePart | list[ToolResponsePart] | None + resume_restart: ToolRequestPart | list[ToolRequestPart] | None + resume_metadata: dict[str, Any] | None return_tool_requests: bool | None max_turns: int | None on_chunk: ModelStreamingCallback | None @@ -112,6 +140,27 @@ class PromptGenerateOptions(TypedDict, total=False): metadata: dict[str, Any] | None +_PROMPT_GENERATE_OPTION_KEYS: frozenset[str] = frozenset(PromptGenerateOptions.__annotations__) + + +def _coerce_prompt_opts(opts: Mapping[str, Any]) -> PromptGenerateOptions: + """Build effective opts from a kwargs mapping: drop keys whose value is None. + + Rejects unknown keys at runtime (Unpack only enforces this for type checkers). + """ + raw = dict(opts) + unknown = set(raw) - _PROMPT_GENERATE_OPTION_KEYS + if unknown: + if 'opts' in unknown: + raise TypeError( + 'Passing a combined `opts` dict is not supported; use keyword arguments ' + '(e.g. model=..., config=...) matching PromptGenerateOptions.' 
+ ) + sorted_unknown = ', '.join(sorted(unknown)) + raise TypeError(f'Unexpected keyword arguments for prompt execution: {sorted_unknown}') + return cast(PromptGenerateOptions, {k: v for k, v in raw.items() if v is not None}) + + class ModelStreamResponse(Generic[OutputT]): """Response from streaming prompt execution with stream and response properties.""" @@ -183,11 +232,13 @@ class PromptConfig(BaseModel): max_turns: int | None = None return_tool_requests: bool | None = None metadata: dict[str, Any] | None = None - tools: list[str] | None = None + tools: Sequence[str | Tool] | None = None tool_choice: ToolChoice | None = None use: list[ModelMiddleware] | None = None docs: list[Document] | None = None - tool_responses: list[Part] | None = None + resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None + resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None + resume_metadata: dict[str, Any] | None = None resources: list[str] | None = None @@ -213,7 +264,7 @@ def __init__( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, Any] | None = None, - tools: list[str] | None = None, + tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, docs: list[Document] | None = None, @@ -295,60 +346,42 @@ async def _ensure_resolved(self) -> None: async def __call__( self, - input: InputT | None = None, - opts: PromptGenerateOptions | None = None, + input: InputT | dict[str, Any] | None = None, + **opts: Unpack[PromptGenerateOptions], ) -> ModelResponse[OutputT]: - """Execute the prompt and return the response.""" - await self._ensure_resolved() - effective_opts: PromptGenerateOptions = opts if opts else {} + """Execute the prompt and return the response. 
- # Extract streaming callback and middleware from opts - on_chunk = effective_opts.get('on_chunk') - middleware = effective_opts.get('use') or self._use - context = effective_opts.get('context') + Args: + input: Template variables for rendering. + """ + effective_opts = _coerce_prompt_opts(opts) + return await self._call_impl(input, effective_opts) + async def _call_impl( + self, + input: InputT | dict[str, Any] | None, + opts: PromptGenerateOptions, + ) -> ModelResponse[OutputT]: + """Execute the prompt with resolved opts. Used by __call__ and stream.""" + await self._ensure_resolved() + on_chunk = opts.get('on_chunk') + middleware = opts.get('use') or self._use + context = opts.get('context') result = await generate_action( self._registry, - await self.render(input=input, opts=effective_opts), + await self._render_impl(input, opts), on_chunk=on_chunk, middleware=middleware, context=context if context else ActionRunContext._current_context(), # pyright: ignore[reportPrivateUsage] ) - # Cast to preserve the generic type parameter return cast(ModelResponse[OutputT], result) - def stream( - self, - input: InputT | None = None, - opts: PromptGenerateOptions | None = None, - *, - timeout: float | None = None, - ) -> ModelStreamResponse[OutputT]: - """Stream the prompt execution, returning (stream, response_future).""" - effective_opts: PromptGenerateOptions = opts if opts else {} - channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout) - - # Create a copy of opts with the streaming callback - stream_opts: PromptGenerateOptions = { - **effective_opts, - 'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)), - } - - resp = self.__call__(input=input, opts=stream_opts) - response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp) - channel.set_close_future(response_future) - - return ModelStreamResponse[OutputT](channel=channel, response_future=response_future) - - async def render( + async def 
_render_impl( self, - input: InputT | dict[str, Any] | None = None, - opts: PromptGenerateOptions | None = None, + input: InputT | dict[str, Any] | None, + opts: PromptGenerateOptions, ) -> GenerateActionOptions: - """Render the prompt template without executing, returning GenerateActionOptions.""" - await self._ensure_resolved() - if opts is None: - opts = cast(PromptGenerateOptions, {}) + """Render the prompt with resolved opts. Used by render() and _call_impl.""" output_opts = opts.get('output') or {} context = opts.get('context') @@ -395,46 +428,37 @@ def _or(opt_val: Any, default: Any) -> Any: # noqa: ANN401 metadata=merged_metadata, docs=self._docs, resources=opts.get('resources') or self._resources, + resume_respond=opts.get('resume_respond'), + resume_restart=opts.get('resume_restart'), + resume_metadata=opts.get('resume_metadata'), ) - model = prompt_options.model or self._registry.default_model - if model is None: + model_name = prompt_options.model or self._registry.default_model + if model_name is None: raise GenkitError(status='INVALID_ARGUMENT', message='No model configured.') resolved_msgs: list[Message] = [] # Convert input to dict for render functions - # If input is a Pydantic model, convert to dict; otherwise use as-is render_input: dict[str, Any] if input is None: render_input = {} elif isinstance(input, dict): - # Type narrow: input is dict here, assign to dict[str, Any] typed variable render_input = {str(k): v for k, v in input.items()} elif isinstance(input, BaseModel): - # Pydantic v2 model render_input = input.model_dump() elif hasattr(input, 'dict'): - # Pydantic v1 model dict_func = getattr(input, 'dict', None) render_input = cast(Callable[[], dict[str, Any]], dict_func)() else: - # Fallback: cast to dict (should not happen with proper typing) render_input = cast(dict[str, Any], input) - # Get opts.messages for history (matching JS behavior) - opts_messages = opts.get('messages') - # Render system prompt + opts_messages = 
opts.get('messages') if prompt_options.system: result = await render_system_prompt( self._registry, render_input, prompt_options, self._cache_prompt, context ) resolved_msgs.append(result) - - # Handle messages (matching JS behavior): - # - If prompt has messages template: render it (opts.messages passed as history to resolvers) - # - If prompt has no messages: use opts.messages directly if prompt_options.messages: - # Prompt defines messages - render them (resolvers receive opts_messages as history) resolved_msgs.extend( await render_message_prompt( self._registry, @@ -446,89 +470,83 @@ def _or(opt_val: Any, default: Any) -> Any: # noqa: ANN401 ) ) elif opts_messages: - # Prompt has no messages template - use opts.messages directly resolved_msgs.extend(opts_messages) - - # Render user prompt if prompt_options.prompt: result = await render_user_prompt(self._registry, render_input, prompt_options, self._cache_prompt, context) resolved_msgs.append(result) - # If schema is set but format is not explicitly set, default to 'json' format if prompt_options.output_schema and not prompt_options.output_format: - output_format = 'json' + out_format = 'json' else: - output_format = prompt_options.output_format + out_format = prompt_options.output_format - # Build output config - output = GenerateActionOutputConfig() - if output_format: - output.format = output_format + output_config = GenerateActionOutputConfig() + if out_format: + output_config.format = out_format if prompt_options.output_content_type: - output.content_type = prompt_options.output_content_type + output_config.content_type = prompt_options.output_content_type if prompt_options.output_instructions is not None: - output.instructions = prompt_options.output_instructions - _resolve_output_schema(self._registry, prompt_options.output_schema, output) + output_config.instructions = prompt_options.output_instructions + _resolve_output_schema(self._registry, prompt_options.output_schema, output_config) if 
prompt_options.output_constrained is not None: - output.constrained = prompt_options.output_constrained + output_config.constrained = prompt_options.output_constrained - # Handle resume options - resume = None - resume_opts = opts.get('resume') - if resume_opts: - respond = resume_opts.get('respond') - if respond: - resume = Resume(respond=respond) if isinstance(respond, list) else Resume(respond=[respond]) + resume_result = resume_options_to_resume( + resume_respond=opts.get('resume_respond'), + resume_restart=opts.get('resume_restart'), + resume_metadata=opts.get('resume_metadata'), + ) - # Merge docs: opts.docs extends prompt docs merged_docs = await render_docs(render_input, prompt_options, context) opts_docs = opts.get('docs') if opts_docs: merged_docs = [*merged_docs, *opts_docs] if merged_docs else list(opts_docs) return GenerateActionOptions( - model=model, - messages=resolved_msgs, + model=model_name, + messages=resolved_msgs, # type: ignore[arg-type] config=prompt_options.config, - tools=prompt_options.tools, + tools=tools_to_action_names(prompt_options.tools), return_tool_requests=prompt_options.return_tool_requests, tool_choice=prompt_options.tool_choice, - output=output, + output=output_config, max_turns=prompt_options.max_turns, docs=merged_docs, # type: ignore[arg-type] - resume=resume, + resume=resume_result, ) - async def as_tool(self) -> Action: - """Expose this prompt as a tool. - - Returns the PROMPT action, which can be used as a tool. - """ - await self._ensure_resolved() - # If we have a direct reference to the action, use it - if self._prompt_action is not None: - return self._prompt_action - - # Otherwise, try to look it up using name/variant/ns - if self._name is None: - raise GenkitError( - status='FAILED_PRECONDITION', - message=( - 'Prompt name not available. This prompt was not created via define_prompt_async() or load_prompt().' 
- ), - ) - - lookup_key = registry_lookup_key(self._name, self._variant, self._ns) + def stream( + self, + input: InputT | dict[str, Any] | None = None, + *, + timeout: float | None = None, + **opts: Unpack[PromptGenerateOptions], + ) -> ModelStreamResponse[OutputT]: + """Stream the prompt execution, returning (stream, response_future).""" + effective_opts = _coerce_prompt_opts(opts) + channel: Channel[ModelResponseChunk, ModelResponse[OutputT]] = Channel(timeout=timeout) + stream_opts: PromptGenerateOptions = { + **effective_opts, + 'on_chunk': lambda c: channel.send(cast(ModelResponseChunk, c)), + } + resp = self._call_impl(input, stream_opts) + response_future: asyncio.Future[ModelResponse[OutputT]] = asyncio.create_task(resp) + channel.set_close_future(response_future) - action = await self._registry.resolve_action_by_key(lookup_key) + return ModelStreamResponse[OutputT](channel=channel, response_future=response_future) - if action is None or action.kind != ActionKind.PROMPT: - raise GenkitError( - status='NOT_FOUND', - message=f'PROMPT action not found for prompt "{self._name}"', - ) + async def render( + self, + input: InputT | dict[str, Any] | None = None, + **opts: Unpack[PromptGenerateOptions], + ) -> GenerateActionOptions: + """Render the prompt template without executing, returning GenerateActionOptions. - return action + Same keyword options as ``__call__`` (see PromptGenerateOptions). 
+ """ + await self._ensure_resolved() + coerced = _coerce_prompt_opts(opts) + return await self._render_impl(input, coerced) def register_prompt_actions( @@ -651,18 +669,20 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) if options.output_constrained is not None: output.constrained = options.output_constrained - resume = None - if options.tool_responses: - # Filter for only ToolResponsePart instances - tool_response_parts = [r.root for r in options.tool_responses if isinstance(r.root, ToolResponsePart)] - if tool_response_parts: - resume = Resume(respond=tool_response_parts) + resume = resume_options_to_resume( + resume_respond=options.resume_respond, + resume_restart=options.resume_restart, + resume_metadata=options.resume_metadata, + ) + + # Convert tool refs (str name or Tool object) to string names for GenerateActionOptions + tools_refs = tools_to_action_names(options.tools) return GenerateActionOptions( model=model, messages=resolved_msgs, # type: ignore[arg-type] config=options.config, - tools=options.tools, + tools=tools_refs, return_tool_requests=options.return_tool_requests, tool_choice=options.tool_choice, output=output, @@ -676,11 +696,8 @@ async def to_generate_request(registry: Registry, options: GenerateActionOptions """Convert GenerateActionOptions to ModelRequest, resolving tool names.""" tools: list[Action] = [] if options.tools: - for tool_name in options.tools: - tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) - if tool_action is None: - raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_name}') - tools.append(tool_action) + for tool_ref in options.tools: + tools.append(await resolve_tool(registry, tool_ref)) tool_defs = [to_tool_definition(tool) for tool in tools] if tools else [] diff --git a/py/packages/genkit/src/genkit/_ai/_tools.py b/py/packages/genkit/src/genkit/_ai/_tools.py index f603f66f79..3dc4be7732 100644 --- 
a/py/packages/genkit/src/genkit/_ai/_tools.py +++ b/py/packages/genkit/src/genkit/_ai/_tools.py @@ -18,15 +18,145 @@ import inspect from collections.abc import Callable -from functools import wraps -from typing import Any, NoReturn, ParamSpec, TypeVar, cast +from contextvars import ContextVar +from typing import Any, cast -from genkit._core._action import ActionKind, ActionRunContext +from pydantic import BaseModel + +from genkit._core._action import Action, ActionKind, ActionRunContext +from genkit._core._error import GenkitError from genkit._core._registry import Registry -from genkit._core._typing import Part, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart +from genkit._core._typing import ToolDefinition, ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart + + +class Tool: + """A registered tool: a callable handle backed by an :class:`~genkit._core._action.Action`. + + Obtain instances via :func:`define_tool`, :func:`define_interrupt`, or the + ``@ai.tool`` decorator rather than constructing directly. 
+ """ + + def __init__(self, action: Action) -> None: + self._action = action + + @property + def name(self) -> str: + """Tool name (registry key).""" + return self._action.name + + @property + def description(self) -> str: + """Human-readable description sent to the model.""" + return self._action.description or '' + + @property + def input_schema(self) -> dict[str, object] | None: + """JSON Schema for the tool's input, as sent on the wire.""" + return self._action.input_schema + + @property + def output_schema(self) -> dict[str, object] | None: + """JSON Schema for the tool's output.""" + return self._action.output_schema + + def definition(self) -> ToolDefinition: + """Return the wire-format ToolDefinition for this tool.""" + return ToolDefinition( + name=self.name, + description=self.description, + input_schema=self.input_schema, + output_schema=self.output_schema, + ) + + async def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + """Run the tool and return the unwrapped response value.""" + return (await self._action.run(*args, **kwargs)).response + + def restart( + self, + replace_input: Any | None = None, # noqa: ANN401 + *, + interrupt: ToolRequestPart, + resumed_metadata: dict[str, Any] | None = None, + ) -> ToolRequestPart: + """Create a restart request for an interrupted tool call. + + Args: + replace_input: Optional new ``tool_request.input`` for this run (previous input is + stored in ``metadata.replacedInput`` when this is set). + interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). + resumed_metadata: Passed to the tool as ``ToolRunContext.resumed_metadata``. + + Returns: + A ``ToolRequestPart`` for ``resume_restart`` / message history. 
+ + Example: + ``pay_invoice.restart({**trp.tool_request.input, "confirmed": True}, interrupt=trp,`` + ``resumed_metadata={"by": "bob"})`` + """ + tool_req = interrupt.tool_request + if tool_req.name != self.name: + raise ValueError(f"Interrupt is for tool '{tool_req.name}', not '{self.name}'") + + existing_meta = interrupt.metadata or {} + new_meta: dict[str, Any] = dict(existing_meta) if existing_meta else {} + + new_meta['resumed'] = resumed_metadata if resumed_metadata is not None else True + + new_input = tool_req.input + if replace_input is not None: + new_meta['replacedInput'] = tool_req.input + new_input = replace_input + + return ToolRequestPart( + tool_request=ToolRequest( + name=tool_req.name, + ref=tool_req.ref, + input=new_input, + ), + metadata=new_meta, + ) -P = ParamSpec('P') -T = TypeVar('T') + +# Context variables for propagating resumed metadata to tools +_tool_resumed_metadata: ContextVar[dict[str, Any] | None] = ContextVar('tool_resumed_metadata', default=None) +# Stashed copy of tool_request.input when restart replaces input (JSON; shape is per tool). +_tool_original_input: ContextVar[Any | None] = ContextVar('tool_original_input', default=None) # noqa: ANN401 + + +def _json_schema_root_is_scalar_or_array(schema: dict[str, object] | None) -> bool: + """Return True if the JSON Schema root is a scalar or array type.""" + if not schema: + return False + t = schema.get('type') + if isinstance(t, str): + tl = t.lower() + return tl in ('string', 'number', 'integer', 'boolean', 'array') + return False + + +def unwrap_wrapped_scalar_tool_input_if_needed( + input_payload: Any, # noqa: ANN401 - wire JSON from model; schema varies per tool + input_schema: dict[str, object] | None, +) -> Any: # noqa: ANN401 - same payload after optional {"value": x} unwrap + """Unwrap ``{"value": x}`` before calling a tool whose JSON Schema root is scalar/array. 
+ + The Google Genai Gemini plugin wraps scalar/array roots as ``{"value": }`` + at declaration time so the Gemini API accepts them. The model then sends arguments + in that same shape. This helper strips the wrapper so the tool handler receives the + bare value it was defined with. + + Note: ``ToolRequest.input`` is NOT modified — this unwrap happens only at call time + so that message history keeps the original wire format (required for subsequent + Gemini turns where ``FunctionCall.args`` must be a dict). + """ + if not _json_schema_root_is_scalar_or_array(input_schema): + return input_payload + if not isinstance(input_payload, dict): + return input_payload + if set(input_payload.keys()) != {'value'}: + return input_payload + return input_payload['value'] class ToolRunContext(ActionRunContext): @@ -35,49 +165,126 @@ class ToolRunContext(ActionRunContext): def __init__( self, ctx: ActionRunContext, + resumed_metadata: dict[str, Any] | None = None, + original_input: Any = None, # noqa: ANN401 - prior tool_request.input when replacing on restart ) -> None: - """Initialize from parent ActionRunContext.""" + """Initialize from parent ActionRunContext. 
+ + Args: + ctx: Parent action context + resumed_metadata: Metadata from previous interrupt (if resumed) + original_input: Original tool input before replacement (if resumed) + """ super().__init__(context=ctx.context) + self.resumed_metadata = resumed_metadata + self.original_input = original_input - def interrupt(self, metadata: dict[str, Any] | None = None) -> NoReturn: - """Raise ToolInterruptError to pause execution (e.g., for user input).""" - raise ToolInterruptError(metadata=metadata) + def is_resumed(self) -> bool: + """Return True if this execution is resuming after an interrupt.""" + return self.resumed_metadata is not None -# TODO(#4346): make this extend GenkitError once it has INTERRUPTED status -class ToolInterruptError(Exception): - """Controlled interruption of tool execution (e.g., to request user input).""" +class Interrupt(Exception): # noqa: N818 - public Genkit name; not renamed *Error for style + """Exception for interrupting tool execution with user-facing API. + + Raise ``Interrupt(data)`` from a tool or from tool middleware (e.g. ``wrap_tool``). + Exceptions from ``tool.run`` are wrapped in GenkitError + with ``cause=Interrupt``; generation attaches interrupt metadata to the pending tool + request. + + To resume, use ``respond_to_interrupt`` or ``tool.restart(...)`` on the + registered Tool. + """ - def __init__(self, metadata: dict[str, Any] | None = None) -> None: - """Initialize with optional interrupt metadata.""" + def __init__(self, data: dict[str, Any] | None = None) -> None: + """Initialize an Interrupt exception. + + Args: + data: Interrupt metadata (attached to the tool request on the wire). Use a + plain dict; for a Pydantic model, pass ``m.model_dump(mode="json")``. 
+ """ super().__init__() - self.metadata: dict[str, Any] = metadata or {} + self.data: dict[str, Any] = {} if data is None else data -def tool_response( - interrupt: Part | ToolRequestPart, - response_data: object | None = None, +def _tool_response_part( + interrupt: ToolRequestPart, + output: Any, # noqa: ANN401 - arbitrary tool/interrupt reply payload (JSON) metadata: dict[str, Any] | None = None, -) -> Part: - """Create a ToolResponse Part for an interrupted tool request.""" - # TODO(#4347): validate against tool schema - tool_request = interrupt.root.tool_request if isinstance(interrupt, Part) else interrupt.tool_request - - interrupt_metadata: dict[str, Any] | bool = True - if isinstance(metadata, dict): - interrupt_metadata = metadata - elif metadata: - interrupt_metadata = metadata - - tr = cast(ToolRequest, tool_request) - return Part( - root=ToolResponsePart( - tool_response=ToolResponse( - name=tr.name, - ref=tr.ref, - output=response_data, - ), - metadata={'interruptResponse': interrupt_metadata}, +) -> ToolResponsePart: + """Build a ``ToolResponsePart`` for an interrupted tool request (interrupt reply channel).""" + interrupt_metadata = metadata if metadata is not None else True + tool_req = interrupt.tool_request + return ToolResponsePart( + tool_response=ToolResponse( + ref=tool_req.ref, + name=tool_req.name, + output=output, + ), + metadata={'interruptResponse': interrupt_metadata}, + ) + + +def respond_to_interrupt( + response: Any, # noqa: ANN401 - user reply or tool output for resume_respond + *, + interrupt: ToolRequestPart, + metadata: dict[str, Any] | None = None, +) -> ToolResponsePart: + """Build a ``ToolResponsePart`` for a pending tool interrupt. + + Pass the return value to ``generate(..., resume_respond=interrupt_response)``. + + Args: + response: Tool output / user reply for this interrupt. + interrupt: The interrupted ``ToolRequestPart`` (e.g. from ``response.interrupts``). 
+ metadata: Optional metadata for the interrupt response channel. + """ + return _tool_response_part(interrupt, response, metadata) + + +async def run_tool_after_restart(tool: Action[Any, Any, Any], restart_trp: ToolRequestPart) -> ToolResponsePart: + """Run a tool for ``resume_restart``: applies ``resumed`` / ``replacedInput`` from metadata. + + Sets the same context variables as the tool wrapper so ToolRunContext reflects + a resumed run. Nested interrupts during restart are not supported and raise GenkitError. + """ + meta = restart_trp.metadata or {} + raw_resumed = meta.get('resumed') + if raw_resumed is True: + resumed_meta: dict[str, Any] | None = {} + elif isinstance(raw_resumed, dict): + resumed_meta = raw_resumed + else: + resumed_meta = None + original_input = meta.get('replacedInput') + + token_meta = _tool_resumed_metadata.set(resumed_meta) + token_input = _tool_original_input.set(original_input) + try: + try: + tool_in = unwrap_wrapped_scalar_tool_input_if_needed( + restart_trp.tool_request.input, + tool.input_schema, + ) + tool_response = (await tool.run(tool_in)).response + except GenkitError as e: + if e.cause and isinstance(e.cause, Interrupt): + raise GenkitError( + status='FAILED_PRECONDITION', + message='Tool interrupted again during a restart execution; not supported yet.', + cause=e.cause, + ) from e + raise + finally: + _tool_resumed_metadata.reset(token_meta) + _tool_original_input.reset(token_input) + + return ToolResponsePart( + tool_response=ToolResponse( + name=restart_trp.tool_request.name, + ref=restart_trp.tool_request.ref, + output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, ) ) @@ -91,43 +298,50 @@ def _get_func_description(func: Callable[..., Any], description: str | None = No return '' -def define_tool( +def _define_tool( registry: Registry, - func: Callable[P, T], + func: Callable[..., Any], name: str | None = None, description: str | None = None, -) -> Callable[P, T]: + *, + 
 input_schema: type[BaseModel] | dict[str, object] | None = None, +) -> Tool: """Register a function as a tool. - Args: - registry: The registry to register the tool in. - func: The async function to register as a tool. Must be a coroutine function. - name: Optional name for the tool. Defaults to the function name. - description: Optional description. Defaults to the function's docstring. + Normally, the input_schema and output_schema are inferred from func. However, + in some cases, like define_interrupt, the app developer doesn't have a way to + express the input schema in the func signature. - Raises: - TypeError: If func is not an async function. + In that case, the app developer can pass in an input_schema to override the inferred schema. + This will ensure that the model requesting the tool will see the correct input shape. """ - # All Python functions have __name__, but ty is strict about Callable protocol if not inspect.iscoroutinefunction(func): - raise TypeError(f'Tool function must be async. Got sync function: {func.__name__}') # ty: ignore[unresolved-attribute] + raise TypeError(f'Tool function must be async. Got sync function: {getattr(func, "__name__", repr(func))}') tool_name = name if name is not None else getattr(func, '__name__', 'unnamed_tool') tool_description = _get_func_description(func, description) input_spec = inspect.getfullargspec(func) - func_any = cast(Callable[..., Any], func) - - async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - # Dynamic dispatch based on function signature - pyright can't verify ParamSpec here + async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 - arity dispatch; args/return follow registered tool + # Dynamic dispatch by arity; payload types follow the registered tool (not expressible here). 
match len(input_spec.args): case 0: - return await func_any() + return await func() case 1: - return await func_any(args[0]) + return await func(args[0]) case 2: - return await func_any(args[0], ToolRunContext(cast(ActionRunContext, args[1]))) + # Read from context variables for resumed metadata + resumed_meta = _tool_resumed_metadata.get() + original_input = _tool_original_input.get() + return await func( + args[0], + ToolRunContext( + cast(ActionRunContext, args[1]), + resumed_metadata=resumed_meta, + original_input=original_input, + ), + ) case _: raise ValueError('tool must have 0-2 args...') @@ -138,10 +352,85 @@ async def tool_fn_wrapper(*args: Any) -> Any: # noqa: ANN401 fn=tool_fn_wrapper, metadata_fn=func, ) + if input_schema is not None: + action._override_input_schema(input_schema) + + return Tool(action) + + +def define_tool( + registry: Registry, + func: Callable[..., Any], + name: str | None = None, + description: str | None = None, +) -> Tool: + """Register a function as a tool. + + Tool input/output JSON Schemas are inferred from ``func`` (first parameter and return type). - @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: # noqa: ANN401 - action_any = cast(Any, action) - return (await action_any.run(*args, **kwargs)).response + Args: + registry: The registry to register the tool in. + func: The async function to register as a tool. Must be a coroutine function. + name: Optional name for the tool. Defaults to the function name. + description: Optional description. Defaults to the function's docstring. + + Raises: + TypeError: If func is not an async function. 
+ """ + return _define_tool(registry, func, name, description) - return cast(Callable[P, T], wrapper) + +def define_interrupt( + registry: Registry, + name: str, + *, + description: str | None = None, + request_metadata: dict[str, Any] | Callable[[Any], dict[str, Any]] | None = None, # noqa: ANN401 + input_schema: type[BaseModel] | dict[str, object] | None = None, +) -> Tool: + """Register a tool that always interrupts execution. + + An interrupt tool is a special tool that always raises ``Interrupt`` with + optional metadata. This is useful for explicit human-in-the-loop checkpoints. + For tools that sometimes run logic and sometimes interrupt, use ``define_tool`` + and raise ``Interrupt`` from the handler (or use ``ToolRunContext``). + + Args: + registry: The registry to register the interrupt tool in + name: Tool name (registry key) + description: Tool description shown to the model + request_metadata: Static metadata dict or ``(input) -> dict`` for the interrupt + input_schema: Optional wire input schema (Pydantic model or JSON schema dict). The + interrupt handler is typed as ``Any``; pass this so the model sees a concrete shape. + + Returns: + The registered tool callable (same shape as ``define_tool``). + + Example: + def get_meta(input: dict) -> dict: + return {"action": input.get("action"), "requires_approval": True} + + confirm = define_interrupt( + registry, + "confirm", + description="Requires user approval", + request_metadata=get_meta, + ) + """ + + async def interrupt_wrapper(input: Any) -> Any: # noqa: ANN401 - wire JSON args; never returns (raises Interrupt) + # Interrupt tools accept arbitrary JSON args like any tool. 
+ meta = None + if callable(request_metadata): + meta = request_metadata(input) + elif request_metadata is not None: + meta = request_metadata + raise Interrupt(meta) + + return _define_tool( + registry, + interrupt_wrapper, + name=name, + description=description, + input_schema=input_schema, + ) diff --git a/py/packages/genkit/src/genkit/_core/_action.py b/py/packages/genkit/src/genkit/_core/_action.py index 1d2b2c407f..93d03e8637 100644 --- a/py/packages/genkit/src/genkit/_core/_action.py +++ b/py/packages/genkit/src/genkit/_core/_action.py @@ -35,6 +35,7 @@ from genkit._core._channel import Channel from genkit._core._compat import StrEnum from genkit._core._error import GenkitError +from genkit._core._schema import to_json_schema from genkit._core._trace._path import build_path from genkit._core._trace._suppress import suppress_telemetry from genkit._core._tracing import tracer @@ -347,6 +348,22 @@ def __init__( self._fn = _make_tracing_wrapper(name, kind, span_metadata or {}, n_action_args, fn) self._initialize_io_schemas(action_args, arg_types, resolved_annotations, input_spec) + def _override_input_schema( + self, + input_schema: type[BaseModel] | dict[str, object], + ) -> None: + """Replace input JSON schema (and input validation) when explicitly provided. + + Used when ``metadata_fn`` is loosely typed but the wire contract should be a + Pydantic model or JSON Schema dict. 
+ """ + in_js = to_json_schema(input_schema) + self.input_schema = in_js + if isinstance(input_schema, dict): + self._input_type = None + else: + self._input_type = cast(TypeAdapter[InputT], TypeAdapter(input_schema)) + @property def kind(self) -> ActionKind: return self._kind diff --git a/py/packages/genkit/src/genkit/_core/_flow.py b/py/packages/genkit/src/genkit/_core/_flow.py index 2958832caf..e18371583f 100644 --- a/py/packages/genkit/src/genkit/_core/_flow.py +++ b/py/packages/genkit/src/genkit/_core/_flow.py @@ -67,9 +67,9 @@ def define_flow( """Register an async function as a flow action.""" # All Python functions have __name__, but ty is strict about Callable protocol if not inspect.iscoroutinefunction(func): - raise TypeError(f'Flow must be async: {func.__name__}') # ty: ignore[unresolved-attribute] + raise TypeError(f'Flow must be async: {getattr(func, "__name__", repr(func))}') - flow_name = name or func.__name__ # ty: ignore[unresolved-attribute] + flow_name = name or getattr(func, '__name__', None) or 'unnamed_flow' return registry.register_action( name=flow_name, kind=ActionKind.FLOW, diff --git a/py/packages/genkit/tests/genkit/ai/_tools_test.py b/py/packages/genkit/tests/genkit/ai/_tools_test.py new file mode 100644 index 0000000000..e5228931ba --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/_tools_test.py @@ -0,0 +1,306 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for tool restart builder and run_tool_after_restart.""" + +import pytest + +from genkit import ActionKind, Genkit +from genkit._ai._generate import run_tool_after_restart +from genkit._ai._tools import ( + Interrupt, + ToolRunContext, + _tool_original_input, + _tool_resumed_metadata, + respond_to_interrupt, +) +from genkit._core._error import GenkitError +from genkit._core._typing import ToolRequest, ToolRequestPart, ToolResponsePart + + +@pytest.mark.asyncio +async def test_restart_sets_resumed_metadata_and_preserves_interrupt() -> None: + 
"""``tool.restart``: copy interrupt metadata, set ``resumed``; ``interrupt`` stays on the restart TRP.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='r1', input={'amount': 10}), + metadata={'interrupt': {'reason': 'hold'}}, + ) + out = pay.restart(None, interrupt=interrupt_trp, resumed_metadata={'k': 'v'}) + assert isinstance(out, ToolRequestPart) + assert out.metadata is not None + assert out.metadata.get('resumed') == {'k': 'v'} + assert out.metadata.get('interrupt') == {'reason': 'hold'} + assert out.tool_request.input == {'amount': 10} + + +@pytest.mark.asyncio +async def test_restart_replace_input_sets_replaced_input() -> None: + """Restart with new input sets ``replacedInput`` to prior input and updates ``tool_request.input``.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='r1', input={'amount': 10}), + metadata={'interrupt': True}, + ) + out = pay.restart({'amount': 99}, interrupt=interrupt_trp, resumed_metadata={'by': 'u'}) + assert isinstance(out, ToolRequestPart) + assert out.metadata is not None + assert out.metadata.get('replacedInput') == {'amount': 10} + assert out.tool_request.input == {'amount': 99} + assert out.metadata.get('resumed') == {'by': 'u'} + assert out.metadata.get('interrupt') is True + + +@pytest.mark.asyncio +async def test_restart_resumed_defaults_to_true() -> None: + """When ``resumed_metadata=None``, restart TRP sets ``metadata.resumed`` to True.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='r1', input={}), + metadata={'interrupt': True}, + ) + out = pay.restart(None, interrupt=interrupt_trp, resumed_metadata=None) 
+ assert isinstance(out, ToolRequestPart) + assert out.metadata is not None + assert out.metadata.get('resumed') is True + assert out.metadata.get('interrupt') is True + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_resumed_true_maps_to_empty_dict_in_context() -> None: + """``run_tool_after_restart``: ``metadata.resumed is True`` → ``ToolRunContext.resumed_metadata`` is ``{}``.""" + ai = Genkit() + captured: list[tuple[dict | None, object | None]] = [] + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + captured.append((ctx.resumed_metadata, ctx.original_input)) + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={'q': 1}), + metadata={'resumed': True}, + ) + await run_tool_after_restart(action, restart_trp) + assert len(captured) == 1 + assert captured[0][0] == {} + assert captured[0][1] is None + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_resumed_dict() -> None: + """Restart TRP with ``metadata.resumed`` dict is passed through to ``ToolRunContext.resumed_metadata``.""" + ai = Genkit() + captured: list[dict | None] = [] + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + captured.append(ctx.resumed_metadata) + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={}), + metadata={'resumed': {'by': 'x'}}, + ) + await run_tool_after_restart(action, restart_trp) + assert captured == [{'by': 'x'}] + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_replaced_input() -> None: + """``replacedInput`` on TRP sets tool input from current request and ``original_input`` from prior.""" + ai = Genkit() + captured: list[tuple[object, 
object | None]] = [] + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + captured.append((inp, ctx.original_input)) + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={'new': True}), + metadata={'resumed': True, 'replacedInput': {'old': True}}, + ) + await run_tool_after_restart(action, restart_trp) + assert len(captured) == 1 + assert captured[0][0] == {'new': True} + assert captured[0][1] == {'old': True} + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_resets_contextvars() -> None: + """After ``run_tool_after_restart`` returns, resume ContextVars are cleared (no leak between runs).""" + ai = Genkit() + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={}), + metadata={'resumed': True}, + ) + await run_tool_after_restart(action, restart_trp) + assert _tool_resumed_metadata.get() is None + assert _tool_original_input.get() is None + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_nested_interrupt_raises() -> None: + """Tool raising ``Interrupt`` during a restart run raises ``GenkitError`` (nested interrupt unsupported).""" + ai = Genkit() + + @ai.tool(name='t2') + async def t2(inp: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + raise Interrupt() + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t2') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t2', ref='x', input={}), + metadata={'resumed': True}, + ) + with pytest.raises(GenkitError) as ei: + await run_tool_after_restart(action, restart_trp) + assert 
ei.value.status == 'FAILED_PRECONDITION' + assert 'interrupted again' in ei.value.original_message.lower() + + +# --------------------------------------------------------------------------- +# Wire-format tests: respond_to_interrupt +# --------------------------------------------------------------------------- + + +def test_respond_to_interrupt_wire_format_basic() -> None: + """respond_to_interrupt produces a ToolResponsePart with matching ref/name and interruptResponse metadata.""" + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='ask_user', ref='ref-abc', input={'question': 'ok?'}), + metadata={'interrupt': {'reason': 'needs_approval'}}, + ) + + result = respond_to_interrupt('yes', interrupt=interrupt_trp) + + assert isinstance(result, ToolResponsePart) + assert result.tool_response.name == 'ask_user' + assert result.tool_response.ref == 'ref-abc' + assert result.tool_response.output == 'yes' + assert result.metadata is not None + assert result.metadata.get('interruptResponse') is True + + +def test_respond_to_interrupt_wire_format_with_metadata() -> None: + """respond_to_interrupt attaches custom metadata under interruptResponse key.""" + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='confirm', ref='ref-xyz', input={}), + metadata={'interrupt': True}, + ) + + result = respond_to_interrupt({'approved': True}, interrupt=interrupt_trp, metadata={'by': 'admin'}) + + assert result.tool_response.ref == 'ref-xyz' + assert result.tool_response.output == {'approved': True} + assert result.metadata is not None + assert result.metadata.get('interruptResponse') == {'by': 'admin'} + + +def test_restart_preserves_ref_on_wire() -> None: + """restart() preserves the original tool_request.ref so the resumed TRP can be correlated.""" + ai = Genkit() + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: # noqa: ARG001 + return 'ok' + + interrupt_trp = ToolRequestPart( + tool_request=ToolRequest(name='pay', ref='corr-id-1', 
input={'amount': 50}), + metadata={'interrupt': True}, + ) + out = pay.restart(None, interrupt=interrupt_trp) + + assert out.tool_request.ref == 'corr-id-1' + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_response_preserves_ref() -> None: + """run_tool_after_restart produces a ToolResponsePart whose ref matches the restart TRP's ref.""" + ai = Genkit() + + @ai.tool(name='t_ref') + async def t_ref(inp: dict) -> str: # noqa: ARG001 + return 'done' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='t_ref') + assert action is not None + + restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='t_ref', ref='wire-ref-99', input={}), + metadata={'resumed': True}, + ) + part = await run_tool_after_restart(action, restart_trp) + assert part.tool_response.ref == 'wire-ref-99' + + +@pytest.mark.asyncio +async def test_run_tool_after_restart_response_preserves_ref_and_uses_new_input() -> None: + """``run_tool_after_restart`` returns a ToolResponsePart whose ref matches the restart TRP; + ``tool_request.input`` is what ``tool.run`` receives, and ``metadata.replacedInput`` is + ``ToolRunContext.original_input`` (prior interrupted input). + """ + ai = Genkit() + received_inputs: list[dict] = [] + original_inputs: list[object | None] = [] + + @ai.tool(name='transfer') + async def transfer(inp: dict, ctx: ToolRunContext) -> str: + received_inputs.append(dict(inp)) + original_inputs.append(ctx.original_input) + if not inp.get('confirmed'): + raise Interrupt({'reason': 'needs_approval'}) + return f'transferred {inp.get("amount")}' + + action = await ai.registry.resolve_action(kind=ActionKind.TOOL, name='transfer') + assert action is not None + + prior = {'amount': 100, 'confirmed': False} + # Simulate a restart TRP: original input had confirmed=False, new input has confirmed=True. 
+ restart_trp = ToolRequestPart( + tool_request=ToolRequest(name='transfer', ref='ref-42', input={'amount': 100, 'confirmed': True}), + metadata={'resumed': True, 'replacedInput': prior}, + ) + result = await run_tool_after_restart(action, restart_trp) + + # Ref is preserved from the restart TRP. + assert result.tool_response.ref == 'ref-42' + assert result.tool_response.name == 'transfer' + # Primary arg is current tool_request.input; replacedInput is surfaced as original_input. + assert received_inputs == [{'amount': 100, 'confirmed': True}] + assert original_inputs == [prior] + assert result.tool_response.output == 'transferred 100' diff --git a/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py new file mode 100644 index 0000000000..60c9ada561 --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/generate_helpers_test.py @@ -0,0 +1,90 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for private helpers in genkit._ai._generate (interrupt / resume).""" + +from genkit._ai._generate import ( + _find_corresponding_restart, + _find_corresponding_tool_response, + _interrupt_from_tool_exc, + _to_pending_response, +) +from genkit._ai._tools import Interrupt +from genkit._core._error import GenkitError +from genkit._core._typing import ToolRequest, ToolRequestPart, ToolResponse, ToolResponsePart + + +def test_find_corresponding_restart_matches_name_and_ref() -> None: + """``_find_corresponding_restart`` picks the resume TRP whose name+ref match the pending TRP; else None.""" + pending = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={}), + ) + match = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={'new': True}), + metadata={'resumed': True}, + ) + other_ref = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r2', input={}), + metadata={'resumed': True}, + ) + other_name = ToolRequestPart( + 
tool_request=ToolRequest(name='u', ref='r1', input={}), + metadata={'resumed': True}, + ) + + assert _find_corresponding_restart([match], pending) is match + assert _find_corresponding_restart([other_ref, match], pending) is match + assert _find_corresponding_restart([other_ref], pending) is None + assert _find_corresponding_restart([other_name], pending) is None + assert _find_corresponding_restart(None, pending) is None + assert _find_corresponding_restart([], pending) is None + + +def test_find_corresponding_tool_response_matches_name_and_ref() -> None: + """``_find_corresponding_tool_response`` matches ``ToolResponsePart`` to pending TRP by name+ref.""" + pending = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={}), + ) + trp = ToolResponsePart( + tool_response=ToolResponse(name='t', ref='r1', output=42), + ) + other = ToolResponsePart( + tool_response=ToolResponse(name='t', ref='r2', output=0), + ) + + got = _find_corresponding_tool_response([trp], pending) + assert got is not None + assert got == trp + + assert _find_corresponding_tool_response([other], pending) is None + assert _find_corresponding_tool_response([], pending) is None + + +def test_interrupt_from_tool_exc() -> None: + """``_interrupt_from_tool_exc`` unwraps bare ``Interrupt`` or ``GenkitError.cause``; else None.""" + intr = Interrupt({'x': 1}) + assert _interrupt_from_tool_exc(intr) is intr + + wrapped = GenkitError(message='x', cause=intr) + assert _interrupt_from_tool_exc(wrapped) is intr + + assert _interrupt_from_tool_exc(ValueError('x')) is None + + +def test_to_pending_response_sets_pending_output() -> None: + """``_to_pending_response`` merges prior TRP metadata with ``pendingOutput`` from the tool response.""" + req = ToolRequestPart( + tool_request=ToolRequest(name='t', ref='r1', input={'a': 1}), + metadata={'interrupt': {'old': True}}, + ) + resp = ToolResponsePart( + tool_response=ToolResponse(name='t', ref='r1', output={'out': 2}), + ) + part = 
_to_pending_response(req, resp) + root = part.root + assert isinstance(root, ToolRequestPart) + assert root.tool_request.name == 't' + assert root.tool_request.ref == 'r1' + assert root.metadata is not None + assert root.metadata.get('pendingOutput') == {'out': 2} + assert root.metadata.get('interrupt') == {'old': True} diff --git a/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py new file mode 100644 index 0000000000..26bcfe5af8 --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/generate_interrupt_resume_test.py @@ -0,0 +1,762 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for interrupt, resume, and restart behavior in ``generate_action``.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from genkit import Genkit, Message, ModelResponse +from genkit._ai._generate import generate_action +from genkit._ai._testing import define_programmable_model +from genkit._ai._tools import Interrupt, ToolRunContext, respond_to_interrupt +from genkit._core._error import GenkitError +from genkit._core._model import GenerateActionOptions +from genkit._core._typing import FinishReason, Resume + + +def _wire(messages: list[Message]) -> list[dict[str, Any]]: + """Messages as JSON-shaped dicts (``model_dump`` with aliases) for comparing to expected wire.""" + return [m.model_dump(mode='json', exclude_none=True, by_alias=True) for m in messages] + + +def _gen_opts( + ai: Genkit, *, tools: list[str], messages: list[Message], resume: Resume | None = None +) -> GenerateActionOptions: + return GenerateActionOptions( + model='programmableModel', + messages=messages, + tools=tools, + resume=resume, + ) + + +@pytest.mark.asyncio +async def test_normal_two_arg_tools_see_no_resume_context() -> None: + """Two tools in one batch, no interrupt: ``ToolRunContext`` should stay empty of resume fields. 
+ + The model asks for both tools in one turn; each tool records whether it thinks it's a resume + (it shouldn't), and we compare the whole conversation to the expected wire. + """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + seen: list[tuple[bool, object | None, object | None]] = [] + + @ai.tool(name='u1') + async def u1(_: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + seen.append((ctx.is_resumed(), ctx.resumed_metadata, ctx.original_input)) + return 'a' + + @ai.tool(name='u2') + async def u2(_: dict, ctx: ToolRunContext) -> str: # noqa: ARG001 + seen.append((ctx.is_resumed(), ctx.resumed_metadata, ctx.original_input)) + return 'b' + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'go'}, + {'toolRequest': {'ref': '1', 'name': 'u1', 'input': {}}}, + {'toolRequest': {'ref': '2', 'name': 'u2', 'input': {}}}, + ], + }), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'done'}]}), + ) + ) + + r = await generate_action( + ai.registry, + _gen_opts( + ai, tools=['u1', 'u2'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})] + ), + ) + assert seen == [(False, None, None), (False, None, None)] + assert _wire(r.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'go'}, + {'toolRequest': {'ref': '1', 'name': 'u1', 'input': {}}}, + {'toolRequest': {'ref': '2', 'name': 'u2', 'input': {}}}, + ], + }, + { + 'role': 'tool', + 'content': [ + {'toolResponse': {'ref': '1', 'name': 'u1', 'output': 'a'}}, + {'toolResponse': {'ref': '2', 'name': 'u2', 'output': 'b'}}, + ], + }, + { + 'role': 'model', + 'content': [{'text': 'done'}], + }, + ] + + +@pytest.mark.asyncio +async def test_interrupt_wires_trp_metadata_interrupt_and_stops() -> None: + """When the tool 
raises ``Interrupt``, the interrupt payload lands on the TRP metadata, finish is + ``INTERRUPTED``, and we never get a ``role=tool`` row yet—only user + model in the history. + """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='intr') + async def intr(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'reason': 'x'}) + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'call'}, + {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, + ], + }), + ) + ) + + r = await generate_action( + ai.registry, + _gen_opts(ai, tools=['intr'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert r.finish_reason == FinishReason.INTERRUPTED + assert _wire(r.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + { + 'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}, + 'metadata': {'interrupt': {'reason': 'x'}}, + }, + ], + }, + ] + + +@pytest.mark.asyncio +async def test_resume_respond_trp_gets_resolved_interrupt_and_tool_trp() -> None: + """Follow-up generate with ``Resume(respond=[...])``: the stuck TRP picks up ``resolvedInterrupt``, + the tool reply shows up under ``interruptResponse``, and the model can answer again. Compares + wire before and after the interrupt. 
+ """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='intr') + async def intr(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'reason': 'x'}) + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'call'}, + {'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}}, + ], + }), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'after resume'}]}), + ) + ) + + first = await generate_action( + ai.registry, + _gen_opts(ai, tools=['intr'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + assert _wire(first.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + { + 'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}, + 'metadata': {'interrupt': {'reason': 'x'}}, + }, + ], + }, + ] + + reply = respond_to_interrupt({'bar': 2}, interrupt=first.interrupts[0]) + + second = await generate_action( + ai.registry, + _gen_opts(ai, tools=['intr'], messages=list(first.messages), resume=Resume(respond=[reply])), + ) + + assert second.finish_reason == FinishReason.STOP + assert _wire(second.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'call'}, + { + 'toolRequest': {'ref': 'r1', 'name': 'intr', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'reason': 'x'}}, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + { + 'toolResponse': {'ref': 'r1', 'name': 'intr', 'output': {'bar': 2}}, + 'metadata': {'interruptResponse': True}, + }, + ], + 'metadata': {'resumed': True}, + }, + { + 'role': 'model', + 'content': [{'text': 'after resume'}], + }, + ] + + +@pytest.mark.asyncio +async def 
test_tool_either_interrupts_or_returns() -> None: + """Same tool, two independent generate calls with different inputs. + + First call: ``preapproved=False`` → tool raises Interrupt → finish is INTERRUPTED, no tool row. + Second call: ``preapproved=True`` → tool returns 42 → finish is STOP, full tool+model rows present. + Both results are wire-asserted in full. + """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='bank_transfer') + async def bank_transfer(inp: dict) -> int: + if not inp.get('preapproved'): + raise Interrupt({'reason': 'awaiting_approval'}) + return 42 + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 't'}, + { + 'toolRequest': { + 'ref': 'g', + 'name': 'bank_transfer', + 'input': {'preapproved': False}, + }, + }, + ], + }), + ) + ) + r_fail = await generate_action( + ai.registry, + _gen_opts( + ai, + tools=['bank_transfer'], + messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})], + ), + ) + assert r_fail.finish_reason == FinishReason.INTERRUPTED + assert _wire(r_fail.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 't'}, + { + 'toolRequest': { + 'ref': 'g', + 'name': 'bank_transfer', + 'input': {'preapproved': False}, + }, + 'metadata': {'interrupt': {'reason': 'awaiting_approval'}}, + }, + ], + }, + ] + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 't2'}, + { + 'toolRequest': { + 'ref': 'g2', + 'name': 'bank_transfer', + 'input': {'preapproved': True}, + }, + }, + ], + }), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'ok'}]}), + ) + ) + r_ok = await generate_action( + ai.registry, + _gen_opts( + ai, + 
tools=['bank_transfer'], + messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})], + ), + ) + assert r_ok.finish_reason == FinishReason.STOP + assert _wire(r_ok.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 't2'}, + { + 'toolRequest': { + 'ref': 'g2', + 'name': 'bank_transfer', + 'input': {'preapproved': True}, + }, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + {'toolResponse': {'ref': 'g2', 'name': 'bank_transfer', 'output': 42}}, + ], + }, + { + 'role': 'model', + 'content': [{'text': 'ok'}], + }, + ] + + +@pytest.mark.asyncio +async def test_resume_restart_runs_tool_second_time_and_resolved_interrupt_on_model() -> None: + """``Resume(restart=[...])`` reruns the tool with new input after an interrupt. The tool runs + twice (tracked in ``calls``); the second pass shows ``resolvedInterrupt`` on the model TRP and a + plain ``toolResponse`` (no ``interruptResponse`` on that path). + """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + calls: list[str] = [] + + @ai.tool(name='pay') + async def pay(inp: dict) -> str: + calls.append('run') + if not inp.get('ok'): + raise Interrupt({'hold': True}) + return 'paid' + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'x'}, + {'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}}, + ], + }), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'final'}]}), + ) + ) + # ^ Queued for the second generate call (after restart re-runs the tool). 
+ + first = await generate_action( + ai.registry, + _gen_opts(ai, tools=['pay'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})]), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + assert _wire(first.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'x'}, + { + 'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}, + 'metadata': {'interrupt': {'hold': True}}, + }, + ], + }, + ] + + restart_trp = pay.restart({'ok': True}, interrupt=first.interrupts[0], resumed_metadata={'by': 'test'}) + + second = await generate_action( + ai.registry, + _gen_opts(ai, tools=['pay'], messages=list(first.messages), resume=Resume(restart=[restart_trp])), + ) + + assert second.finish_reason == FinishReason.STOP + assert calls == ['run', 'run'] + assert _wire(second.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'x'}, + { + 'toolRequest': {'ref': 'p1', 'name': 'pay', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'hold': True}}, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + {'toolResponse': {'ref': 'p1', 'name': 'pay', 'output': 'paid'}}, + ], + 'metadata': {'resumed': True}, + }, + { + 'role': 'model', + 'content': [{'text': 'final'}], + }, + ] + + +@pytest.mark.asyncio +async def test_mixed_resume_one_respond_one_restart() -> None: + """Two tool calls both interrupt in one turn; the next generate fills in a ``respond`` for one + ref and a ``restart`` for the other. Expect one tool message with two parts (respond path still + has ``interruptResponse``; restart path does not). 
+ """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='a') + async def a_tool(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'tool': 'a'}) + + @ai.tool(name='b') + async def b_tool(inp: dict) -> str: + if inp.get('ok'): + return 'b-done' + raise Interrupt({'tool': 'b'}) + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'text': 'both'}, + {'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}}, + {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, + ], + }), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'end'}]}), + ) + ) + + first = await generate_action( + ai.registry, + _gen_opts( + ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})] + ), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + assert _wire(first.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 'model', + 'content': [ + {'text': 'both'}, + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'interrupt': {'tool': 'a'}}, + }, + { + 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, + 'metadata': {'interrupt': {'tool': 'b'}}, + }, + ], + }, + ] + + ia = next(p for p in first.interrupts if p.tool_request.name == 'a') + ib = next(p for p in first.interrupts if p.tool_request.name == 'b') + + second = await generate_action( + ai.registry, + _gen_opts( + ai, + tools=['a', 'b'], + messages=list(first.messages), + resume=Resume( + respond=[respond_to_interrupt({'done': True}, interrupt=ia)], + restart=[b_tool.restart({'ok': True}, interrupt=ib, resumed_metadata=None)], + ), + ), + ) + + assert second.finish_reason == FinishReason.STOP + assert _wire(second.messages) == [ + { + 'role': 'user', + 'content': [{'text': 'hi'}], + }, + { + 'role': 
'model', + 'content': [ + {'text': 'both'}, + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'tool': 'a'}}, + }, + { + 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'tool': 'b'}}, + }, + ], + }, + { + 'role': 'tool', + 'content': [ + { + 'toolResponse': {'ref': 'ra', 'name': 'a', 'output': {'done': True}}, + 'metadata': {'interruptResponse': True}, + }, + # Restart path: tool re-runs and returns its own output; no interruptResponse metadata. + {'toolResponse': {'ref': 'rb', 'name': 'b', 'output': 'b-done'}}, + ], + 'metadata': {'resumed': True}, + }, + { + 'role': 'model', + 'content': [{'text': 'end'}], + }, + ] + + +@pytest.mark.asyncio +async def test_mixed_one_interrupts_one_succeeds_pending_output_in_wire() -> None: + """Two tools in one turn: ``a`` interrupts, ``b`` succeeds. + + Turn 1: both run in parallel. ``b``'s output is stashed as ``pendingOutput`` + on its TRP in the model message (no tool message yet). finish=INTERRUPTED. + + Turn 2: resume with ``respond=[...]`` for ``a`` only — no action needed for + ``b``. The framework reconstructs ``b``'s tool response from the stashed + output, strips ``pendingOutput`` from ``b``'s model TRP, and + marks the tool response ``source: pending`` on the wire. 
+ """ + ai = Genkit() + pm, _ = define_programmable_model(ai) + + @ai.tool(name='a') + async def a_tool(_: dict) -> str: # noqa: ARG001 + raise Interrupt({'reason': 'needs_approval'}) + + @ai.tool(name='b') + async def b_tool(_: dict) -> int: # noqa: ARG001 + return 42 + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({ + 'role': 'model', + 'content': [ + {'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}}, + {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, + ], + }), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message.model_validate({'role': 'model', 'content': [{'text': 'done'}]}), + ) + ) + + first = await generate_action( + ai.registry, + _gen_opts( + ai, tools=['a', 'b'], messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]})] + ), + ) + assert first.finish_reason == FinishReason.INTERRUPTED + # b's output is stashed in pendingOutput on its TRP; no tool message yet. 
+ assert _wire(first.messages) == [ + {'role': 'user', 'content': [{'text': 'hi'}]}, + { + 'role': 'model', + 'content': [ + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'interrupt': {'reason': 'needs_approval'}}, + }, + { + 'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}, + 'metadata': {'pendingOutput': 42}, + }, + ], + }, + ] + + ia = first.interrupts[0] + second = await generate_action( + ai.registry, + _gen_opts( + ai, + tools=['a', 'b'], + messages=list(first.messages), + resume=Resume(respond=[respond_to_interrupt({'approved': True}, interrupt=ia)]), + ), + ) + + assert second.finish_reason == FinishReason.STOP + assert _wire(second.messages) == [ + {'role': 'user', 'content': [{'text': 'hi'}]}, + { + 'role': 'model', + 'content': [ + { + 'toolRequest': {'ref': 'ra', 'name': 'a', 'input': {}}, + 'metadata': {'resolvedInterrupt': {'reason': 'needs_approval'}}, + }, + {'toolRequest': {'ref': 'rb', 'name': 'b', 'input': {}}}, + ], + }, + { + 'role': 'tool', + 'content': [ + { + 'toolResponse': {'ref': 'ra', 'name': 'a', 'output': {'approved': True}}, + 'metadata': {'interruptResponse': True}, + }, + { + # b ran on turn 1; output reconstructed from pendingOutput stash. + 'toolResponse': {'ref': 'rb', 'name': 'b', 'output': 42}, + 'metadata': {'source': 'pending'}, + }, + ], + 'metadata': {'resumed': True}, + }, + {'role': 'model', 'content': [{'text': 'done'}]}, + ] + + +@pytest.mark.asyncio +async def test_resume_without_matching_replies_raises() -> None: + """Hand-built history with an interrupted TRP but an empty ``Resume()``: expect ``GenkitError`` + and a message that mentions replies or restarts. 
+ """ + ai = Genkit() + _, _ = define_programmable_model(ai) + + with pytest.raises(GenkitError) as ei: + await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[ + Message.model_validate({'role': 'user', 'content': [{'text': 'hi'}]}), + Message.model_validate({ + 'role': 'model', + 'content': [ + { + 'toolRequest': {'ref': 'z', 'name': 'missing', 'input': {}}, + 'metadata': {'interrupt': True}, + }, + ], + }), + ], + resume=Resume(), + ), + ) + assert ei.value.status == 'INVALID_ARGUMENT' + assert 'unresolved tool request' in ei.value.original_message.lower() + + +@pytest.mark.asyncio +async def test_resume_requires_last_message_model_with_tool_requests() -> None: + """Can't resume when the transcript ends on a user turn: ``GenkitError``, and the message should + mention needing a model message. + """ + ai = Genkit() + _, _ = define_programmable_model(ai) + + with pytest.raises(GenkitError) as ei: + await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message.model_validate({'role': 'user', 'content': [{'text': 'only user'}]})], + resume=Resume(), + ), + ) + assert ei.value.status == 'FAILED_PRECONDITION' + assert "cannot 'resume'" in ei.value.original_message.lower() diff --git a/py/packages/genkit/tests/genkit/ai/generate_test.py b/py/packages/genkit/tests/genkit/ai/generate_test.py index 0a550c9093..ecb05be823 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_test.py @@ -22,6 +22,7 @@ define_echo_model, define_programmable_model, ) +from genkit._ai._tools import Interrupt from genkit._core._action import ActionRunContext from genkit._core._model import GenerateActionOptions, ModelRequest from genkit._core._typing import ( @@ -30,6 +31,8 @@ Part, Role, TextPart, + ToolRequest, + ToolRequestPart, ) @@ -370,6 +373,79 @@ def collect_chunks(c: ModelResponseChunk) -> None: ] +@pytest.mark.asyncio +async 
def test_parallel_tool_requests_one_interrupt_keeps_pending_output_for_others( + setup_test: tuple[Genkit, ProgrammableModel], +) -> None: + """With asyncio.gather in resolve_tool_requests: one interrupt still records pendingOutput for others.""" + ai, pm = setup_test + + @ai.tool(name='tool_a') + async def tool_a() -> str: + return 'a_ok' + + @ai.tool(name='tool_b') + async def tool_b() -> None: + raise Interrupt({'stop': True}) + + @ai.tool(name='tool_c') + async def tool_c() -> str: + return 'c_ok' + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message( + role=Role.MODEL, + content=[ + Part(TextPart(text='call three')), + Part( + root=ToolRequestPart( + tool_request=ToolRequest(name='tool_a', ref='ref-a', input={}), + ) + ), + Part( + root=ToolRequestPart( + tool_request=ToolRequest(name='tool_b', ref='ref-b', input={}), + ) + ), + Part( + root=ToolRequestPart( + tool_request=ToolRequest(name='tool_c', ref='ref-c', input={}), + ) + ), + ], + ), + ) + ) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[ + Message(role=Role.USER, content=[Part(TextPart(text='hi'))]), + ], + tools=['tool_a', 'tool_b', 'tool_c'], + ), + ) + + assert response.finish_reason == FinishReason.INTERRUPTED + assert response.message is not None + parts = response.message.content + assert len(parts) == 4 + assert parts[0].root == TextPart(text='call three') + a_root = parts[1].root + b_root = parts[2].root + c_root = parts[3].root + assert isinstance(a_root, ToolRequestPart) + assert isinstance(b_root, ToolRequestPart) + assert isinstance(c_root, ToolRequestPart) + assert a_root.metadata and a_root.metadata.get('pendingOutput') == 'a_ok' + assert b_root.metadata and b_root.metadata.get('interrupt') == {'stop': True} + assert c_root.metadata and c_root.metadata.get('pendingOutput') == 'c_ok' + + ########################################################################## # run tests from 
/tests/specs/generate.yaml ########################################################################## diff --git a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py index c47a4e6cd0..6b6cbe0474 100644 --- a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py +++ b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py @@ -13,7 +13,7 @@ from opentelemetry.sdk.trace import TracerProvider from genkit import Genkit -from genkit._core._action import Action, ActionKind, _action_context +from genkit._core._action import _action_context from genkit._core._typing import Operation @@ -40,28 +40,6 @@ def sync_fn() -> str: await ai.run(name='test3', fn=sync_fn) # type: ignore[arg-type] -@pytest.mark.asyncio -async def test_genkit_dynamic_tool() -> None: - """Test Genkit.dynamic_tool method.""" - ai = Genkit() - - async def my_tool(x: int) -> int: - return x + 1 - - tool = ai.dynamic_tool(name='my_tool', fn=my_tool, description='increment x') - - assert isinstance(tool, Action) - assert tool.kind == ActionKind.TOOL - assert tool.name == 'my_tool' - assert tool.description == 'increment x' - assert tool.metadata.get('type') == 'tool' - assert tool.metadata.get('dynamic') is True - - # Execution - resp = await tool.run(5) - assert resp.response == 6 - - @pytest.mark.asyncio async def test_genkit_check_operation() -> None: """Test Genkit.check_operation method.""" diff --git a/py/packages/genkit/tests/genkit/ai/prompt_test.py b/py/packages/genkit/tests/genkit/ai/prompt_test.py index a3ecd5e49f..4ea4194580 100644 --- a/py/packages/genkit/tests/genkit/ai/prompt_test.py +++ b/py/packages/genkit/tests/genkit/ai/prompt_test.py @@ -86,12 +86,12 @@ async def test_simple_prompt_with_override_config() -> None: my_prompt = ai.define_prompt(prompt='hi', config={'banana': True}) # New API: pass config via opts parameter - this MERGES with prompt config - response = await my_prompt(opts={'config': {'temperature': 12}}) + response = 
await my_prompt(config={'temperature': 12}) assert response.text == want_txt # New API: stream also uses opts - result = my_prompt.stream(opts={'config': {'temperature': 12}}) + result = my_prompt.stream(config={'temperature': 12}) assert (await result.response).text == want_txt @@ -244,7 +244,7 @@ async def test_prompt_rendering_dotprompt( my_prompt = ai.define_prompt(**prompt) # New API: use opts parameter to pass config and context - response = await my_prompt(input, opts={'config': input_option, 'context': context}) + response = await my_prompt(input, config=input_option, context=context) assert response.text == want_rendered @@ -488,7 +488,7 @@ async def test_config_merge_priority() -> None: # New API: runtime config is MERGED with prompt config # - temperature: 0.9 (from opts, overrides 0.5) # - banana: 'yellow' (from prompt, preserved) - rendered = await my_prompt.render(opts={'config': {'temperature': 0.9}}) + rendered = await my_prompt.render(config={'temperature': 0.9}) assert rendered.config is not None # Config is now a dict after merging @@ -510,7 +510,7 @@ async def test_opts_can_override_model() -> None: ) # Override model via opts - response = await my_prompt(opts={'model': 'programmableModel'}) + response = await my_prompt(model='programmableModel') # Should use programmableModel, not echoModel assert response.text == 'pm response' @@ -532,7 +532,7 @@ async def test_opts_can_append_messages() -> None: ] # Append conversation history via opts - rendered = await my_prompt.render(opts={'messages': history_messages}) + rendered = await my_prompt.render(messages=history_messages) # Should have: system + history (2) + user prompt = 4 messages assert len(rendered.messages) == 4 @@ -583,13 +583,11 @@ class OutputSchema(BaseModel): output_format='text', # Default to text ) - # Override output via opts + # Override output via kwargs rendered = await my_prompt.render( - opts={ - 'output': { - 'format': 'json', - 'schema': OutputSchema, - } + output={ + 
'format': 'json', + 'schema': OutputSchema, } ) @@ -599,6 +597,43 @@ class OutputSchema(BaseModel): assert rendered.output.json_schema is not None +@pytest.mark.asyncio +async def test_executable_prompt_opts_removed() -> None: + """opts= has been removed; pass options as explicit kwargs (e.g. model=).""" + ai, *_ = setup_test() + + my_prompt = ai.define_prompt(prompt='hi', output_format='text') + + with pytest.raises(TypeError, match='opts'): + # Invalid kwarg on purpose; static checkers need a hint (runtime rejects with TypeError). + await my_prompt(opts={'model': 'echoModel'}) # pyrefly: ignore[unexpected-keyword] # pyright: ignore + + +@pytest.mark.asyncio +async def test_executable_prompt_input_positional_opts_as_kwargs() -> None: + """ExecutablePrompt: input is positional, opts via kwargs after *.""" + ai, *_ = setup_test() + + my_prompt = ai.define_prompt( + prompt='Recipe for {{cuisine}} {{dish}}', + output_format='text', + ) + + # input = positional (template vars), output = kwarg (opts) + rendered = await my_prompt.render( + {'cuisine': 'Italian', 'dish': 'pasta'}, + output={'format': 'text'}, + ) + + # Template vars from input should be in the rendered prompt + assert any('Italian' in str(m) for m in rendered.messages) + assert any('pasta' in str(m) for m in rendered.messages) + + # output kwarg should be respected + assert rendered.output is not None + assert rendered.output.format == 'text' + + # Tests for file-based prompt loading and two-action structure @pytest.mark.asyncio async def test_file_based_prompt_registers_two_actions() -> None: diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index 0310c779f2..f3e08af8e9 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -15,11 +15,11 @@ from genkit import ( Document, Genkit, + Interrupt, Message, ModelResponse, ModelResponseChunk, - ToolRunContext, - 
tool_response, + respond_to_interrupt, ) from genkit._ai._formats._types import FormatDef, Formatter, FormatterConfig from genkit._ai._model import text_from_message @@ -390,9 +390,9 @@ async def test_tool(input: ToolInput) -> int: return (input.value or 0) + 7 @ai.tool(name='test_interrupt') - async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: + async def test_interrupt(input: ToolInput) -> None: """The interrupt.""" - ctx.interrupt({'banana': 'yes please'}) + raise Interrupt({'banana': 'yes please'}) tool_request_msg = Message( Message( @@ -503,9 +503,9 @@ async def test_tool(input: ToolInput) -> int: return (input.value or 0) + 7 @ai.tool(name='test_interrupt') - async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: + async def test_interrupt(input: ToolInput) -> None: """The interrupt.""" - ctx.interrupt({'banana': 'yes please'}) + raise Interrupt({'banana': 'yes please'}) tool_request_msg = Message( Message( @@ -556,11 +556,11 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: assert interrupted_response.messages == [ Message( - role='user', + role=Role.USER, content=[Part(root=TextPart(text='hi'))], ), Message( - role='model', + role=Role.MODEL, content=[ Part(root=TextPart(text='call these tools')), Part( @@ -579,10 +579,12 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: ), ] + respond_wrapped = respond_to_interrupt({'bar': 2}, interrupt=interrupted_response.interrupts[0]) + assert isinstance(respond_wrapped, ToolResponsePart) response = await ai.generate( model='programmableModel', messages=interrupted_response.messages, - tool_responses=[tool_response(interrupted_response.interrupts[0], {'bar': 2})], + resume_respond=[respond_wrapped], tools=['test_tool', 'test_interrupt'], ) @@ -590,11 +592,11 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: assert response.messages == [ Message( - role='user', + role=Role.USER, 
content=[Part(root=TextPart(text='hi'))], ), Message( - role='model', + role=Role.MODEL, content=[ Part(root=TextPart(text='call these tools')), Part( @@ -606,14 +608,14 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: Part( root=ToolRequestPart( tool_request=ToolRequest(ref='234', name='test_tool', input={'value': 5}), - metadata={'pendingOutput': 12}, + metadata=None, ) ), ], metadata=None, ), Message( - role='tool', + role=Role.TOOL, content=[ Part( root=ToolResponsePart( @@ -631,7 +633,7 @@ async def test_interrupt(input: ToolInput, ctx: ToolRunContext) -> None: metadata={'resumed': True}, ), Message( - role='model', + role=Role.MODEL, content=[Part(root=TextPart(text='tool called'))], metadata=None, ), diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 08bdb27947..0f97947aca 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -1076,10 +1076,18 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: Genai tool compatible with Gemini API. """ params = self._convert_schema_property(tool.input_schema) - # Fix for no-arg tools: parameters cannot be None if we want the tool to be callable? - # Actually Google GenAI expects type=OBJECT for params usually. + # Empty params: Gemini requires type=OBJECT even for no-arg tools. if not params: params = genai_types.Schema(type=genai_types.Type.OBJECT, properties={}) + # Gemini rejects scalar/array root schemas. Tool params must be OBJECT with + # properties. LLMs always send {"key": value} — wrap scalar/array in {"value": }. 
+ elif self._is_scalar_or_array_root(params): + params = genai_types.Schema( + type=genai_types.Type.OBJECT, + properties={'value': params}, + required=['value'], + description=params.description, + ) function = genai_types.FunctionDeclaration( name=tool.name, @@ -1089,6 +1097,26 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: ) return genai_types.Tool(function_declarations=[function]) + @staticmethod + def _is_scalar_or_array_root(schema: genai_types.Schema | None) -> bool: + """True if schema root is scalar or array (Gemini rejects these for tool params).""" + if not schema or not schema.type: + return False + t = schema.type + scalar_or_array = { + genai_types.Type.STRING, + genai_types.Type.NUMBER, + genai_types.Type.INTEGER, + genai_types.Type.BOOLEAN, + genai_types.Type.ARRAY, + } + if t not in scalar_or_array: + return False + # Object with properties is fine (already has properties) + if t == genai_types.Type.OBJECT: + return False + return True + def _convert_schema_property( self, input_schema: dict[str, object] | None, defs: dict[str, object] | None = None ) -> genai_types.Schema | None: @@ -1332,7 +1360,10 @@ async def generate(self, request: ModelRequest, ctx: ActionRunContext) -> ModelR ) else: response = await self._generate( - request_contents=request_contents, request_cfg=request_cfg, model_name=model_name, client=client + request_contents=request_contents, + request_cfg=request_cfg, + model_name=model_name, + client=client, ) response.usage = self._create_usage_stats(request=request, response=response) @@ -1508,6 +1539,7 @@ async def _streaming_generate( ) from e accumulated_content: list[Part] = [] + finish_reason: FinishReason | None = None async for response_chunk in generator: content = await self._contents_from_response(response_chunk) if content: # Only process if we have content @@ -1518,12 +1550,29 @@ async def _streaming_generate( role=Role.MODEL, ) ) + # Track finish_reason from last chunk (stream typically sends it 
in final chunk) + if response_chunk.candidates: + for c in response_chunk.candidates: + if c.finish_reason: + fr_name = c.finish_reason.name + if fr_name == 'STOP': + finish_reason = FinishReason.STOP + elif fr_name in ('MAX_TOKENS', 'MAX_OUTPUT_TOKENS'): + finish_reason = FinishReason.LENGTH + elif fr_name in ('SAFETY', 'RECITATION', 'BLOCKLIST', 'PROHIBITED_CONTENT', 'SPII'): + finish_reason = FinishReason.BLOCKED + elif fr_name == 'OTHER': + finish_reason = FinishReason.OTHER + else: + finish_reason = FinishReason.OTHER + break return ModelResponse( message=Message( role=Role.MODEL, content=accumulated_content, - ) + ), + finish_reason=finish_reason, ) @cached_property diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py index 08fd486006..32babfa0a0 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py @@ -278,7 +278,11 @@ def _to_gemini_custom(cls, part: Part | DocumentPart) -> genai.types.Part: return genai.types.Part() @classmethod - def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part: + def from_gemini( + cls, + part: genai.types.Part, + ref: str | None = None, + ) -> Part: """Maps a Gemini Part back to a Genkit Part. 
This method inspects the type of the Gemini Part and converts it into diff --git a/py/plugins/google-genai/test/models/googlegenai_gemini_test.py b/py/plugins/google-genai/test/models/googlegenai_gemini_test.py index 70516cf1f1..b9fba27e28 100644 --- a/py/plugins/google-genai/test/models/googlegenai_gemini_test.py +++ b/py/plugins/google-genai/test/models/googlegenai_gemini_test.py @@ -481,6 +481,39 @@ def test_gemini_model__create_tool( assert isinstance(gemini_tool, genai_types.Tool) +def test_gemini_model__create_tool_wraps_scalar_input_schema( + gemini_model_instance: GeminiModel, +) -> None: + """Scalar/array root input schemas are wrapped in object for Gemini. + + Gemini rejects tool params with type=STRING etc. LLMs always send + {"key": value} — we wrap scalar schemas in {"value": }. + """ + scalar_string_schema = genai_types.Schema( + type=genai_types.Type.STRING, + description='Echo input', + ) + tool_defined = ToolDefinition( + name='echo', + description='Echo the input string', + input_schema={'type': 'string'}, + output_schema=None, + ) + with patch.object( + gemini_model_instance, + '_convert_schema_property', + return_value=scalar_string_schema, + ): + gemini_tool = gemini_model_instance._create_tool(tool_defined) + + decl = gemini_tool.function_declarations[0] + params = decl.parameters + assert params.type == genai_types.Type.OBJECT + assert 'value' in params.properties + assert params.properties['value'].type == genai_types.Type.STRING + assert params.required == ['value'] + + @pytest.mark.parametrize( 'input_schema, defs, expected_schema', [ diff --git a/py/pyproject.toml b/py/pyproject.toml index 756acdb349..411bf53dd1 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -271,7 +271,7 @@ select = [ # Samples are demo code and can use blocking I/O for simplicity "samples/*/*.py" = ["ANN", "D", "E402", "ASYNC"] "samples/*/**/*.py" = ["ANN", "D", "E402", "ASYNC"] -"samples/*/src/*.py" = ["ANN", "D", "E402", "ASYNC"] +"samples/*/src/*.py" = 
["ANN", "D", "E402", "ASYNC", "T201"] [tool.ruff.lint.isort] diff --git a/py/samples/README.md b/py/samples/README.md index 95dabb187e..1548ebc99e 100644 --- a/py/samples/README.md +++ b/py/samples/README.md @@ -35,6 +35,6 @@ Dev UI: http://localhost:4000. Most samples need `GEMINI_API_KEY`. See [plugins/ | `middleware` | Observe or modify model requests | | `output-formats` | Text, enum, JSON, array, and JSONL outputs | | `prompts` | `.prompt` files, variants, helpers, and streaming | -| `tool-interrupts` | Pause a tool for human approval | +| `tool-interrupts` | Trivia (`respond_example.py`) and bank approval (`approval_example.py`) — interrupt + resume | | `tracing` | Watch spans appear in real time | | `vertexai-imagen` | Generate an image with Vertex AI Imagen | diff --git a/py/samples/dynamic-tools/README.md b/py/samples/dynamic-tools/README.md deleted file mode 100644 index 5377eac513..0000000000 --- a/py/samples/dynamic-tools/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# Dynamic Tools - -Learn two related ideas: - -- `ai.dynamic_tool()` creates a tool at runtime. -- `ai.run()` traces a plain async function as a named step. 
- -```bash -export GEMINI_API_KEY=your-api-key -uv sync -uv run src/main.py -``` - -To inspect the same flows in Dev UI: - -```bash -genkit start -- uv run src/main.py -``` diff --git a/py/samples/dynamic-tools/pyproject.toml b/py/samples/dynamic-tools/pyproject.toml deleted file mode 100644 index 9660490674..0000000000 --- a/py/samples/dynamic-tools/pyproject.toml +++ /dev/null @@ -1,18 +0,0 @@ -[project] -name = "dynamic-tools" -version = "0.2.0" -requires-python = ">=3.10" -dependencies = [ - "genkit", - "genkit-plugin-google-genai", - "pydantic>=2.0.0", - "structlog>=24.0.0", - "uvloop>=0.21.0", -] - -[build-system] -build-backend = "hatchling.build" -requires = ["hatchling"] - -[tool.hatch.build.targets.wheel] -packages = ["src"] diff --git a/py/samples/dynamic-tools/src/main.py b/py/samples/dynamic-tools/src/main.py deleted file mode 100644 index bb200eb1aa..0000000000 --- a/py/samples/dynamic-tools/src/main.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# SPDX-License-Identifier: Apache-2.0 - -"""Dynamic tools - create tools at runtime and trace plain functions.""" - -from pydantic import BaseModel, Field - -from genkit import Genkit -from genkit.plugins.google_genai import GoogleAI - -ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-2.5-flash') - - -class DynamicToolInput(BaseModel): - """Input for runtime tool creation.""" - - value: int = Field(default=5, description='Value to square with a runtime tool') - - -class RunStepInput(BaseModel): - """Input for traced step demo.""" - - text: str = Field(default='hello dynamic tools', description='Text to process with traced steps') - - -@ai.flow() -async def dynamic_tool_demo(input: DynamicToolInput) -> str: - """Create a tool at runtime and call it immediately.""" - - async def square(value: int) -> int: - return value * value - - tool = ai.dynamic_tool(name='square', description='Square a number', fn=square) - result = await tool.run(input.value) - return f'{input.value} squared is {result.response}' - - -@ai.flow() -async def run_step_demo(input: RunStepInput) -> dict[str, int | str]: - """Wrap plain async functions in traceable `ai.run()` steps.""" - - async def normalize() -> str: - return input.text.strip().lower() - - async def count_words() -> int: - return len(normalized.split()) - - normalized = await ai.run(name='normalize_text', fn=normalize) - word_count = await ai.run(name='count_words', fn=count_words) - return {'normalized': normalized, 'word_count': word_count} - - -async def main() -> None: - """Run both dynamic tool demos once.""" - - print(await dynamic_tool_demo(DynamicToolInput())) # noqa: T201 - print(await run_step_demo(RunStepInput())) # noqa: T201 - - -if __name__ == '__main__': - ai.run_main(main()) diff --git a/py/samples/tool-interrupts/README.md b/py/samples/tool-interrupts/README.md index f048c2aaca..ca1c5825dc 100644 --- a/py/samples/tool-interrupts/README.md +++ b/py/samples/tool-interrupts/README.md @@ -1,19 +1,33 @@ -# 
Tool Interrupts
+# Tool interrupts
 
-Human-in-the-loop: `ctx.interrupt()` and `tool_response()` let the model pause for user input, then continue.
+Usually the generate loop calls a tool, your code runs, returns a value, and then restarts the generate loop again until it terminates on its own or hits a stopping condition.
+
+With an **interrupt**, the tool **doesn’t** finish that way: a tool can `raise Interrupt(...)` and **hand control back to your application**. Think of it as the tool saying “you handle this step”—collect input, call another service, enforce policy—**instead of** returning a final tool result in one shot.
+
+The Genkit SDK **stops that generation turn**, surfaces the pending tool call (with your payload on **`metadata["interrupt"]`**), and you **`generate` again** later with the **same `messages`** plus **`resume_respond`** or **`resume_restart`**. Either you **inject the tool outcome** (respond) or you **ask the SDK to run the tool again** with new input and metadata.
+
+## Samples
+
+**`respond_example.py`** — Trivia: the “tool” hands off to the CLI; your answer **is** the tool result (`respond_to_interrupt` + `resume_respond`). Prompt: `prompts/trivia_host_cli.prompt`.
+
+**`approval_example.py`** — Bank demo: **`y`** restarts the tool (`resume_restart`); **`n`** declines with respond (`resume_respond`). `USER_MESSAGE` is hardcoded; you only type y/n. Prompt: `prompts/bank_transfer_host_cli.prompt`.
+
+## Run
+
+Requires **`GEMINI_API_KEY`** (Google AI plugin):
 
 ```bash
 export GEMINI_API_KEY=your-api-key
 uv sync
-uv run src/main.py
+uv run src/respond_example.py
+uv run src/approval_example.py
 ```
 
-This launches a small interactive CLI trivia session.
-
-To inspect the same flow in Dev UI instead:
+From repo root:
 
 ```bash
-genkit start -- uv run src/main.py
+uv run --directory py/samples/tool-interrupts python src/respond_example.py
+uv run --directory py/samples/tool-interrupts python src/approval_example.py
 ```
 
-Try `play_trivia`.
+Wire detail: `**MESSAGE_SHAPES.md**`. \ No newline at end of file diff --git a/py/samples/tool-interrupts/prompts/bank_transfer_host_cli.prompt b/py/samples/tool-interrupts/prompts/bank_transfer_host_cli.prompt new file mode 100644 index 0000000000..a197447ea8 --- /dev/null +++ b/py/samples/tool-interrupts/prompts/bank_transfer_host_cli.prompt @@ -0,0 +1,14 @@ +--- +model: googleai/gemini-3-flash-preview +--- + +You are a **bank assistant** in a demo CLI. Help users check balances in plain language, but **any outgoing transfer of money** must go through the approval tool first. + +When the user asks you to **send, wire, or transfer** funds to a person or account, call `request_transfer` with: +- `to_account`: who gets the money (name or masked account id) +- `amount_usd`: amount (e.g. `250.00`) +- `memo`: short reason (e.g. `rent`, `invoice #12`) + +Do **not** pretend the transfer already happened until the user has approved it in the CLI. Keep replies short. + +[user joined online banking] diff --git a/py/samples/tool-interrupts/prompts/trivia_host_cli.prompt b/py/samples/tool-interrupts/prompts/trivia_host_cli.prompt new file mode 100644 index 0000000000..c613dd049a --- /dev/null +++ b/py/samples/tool-interrupts/prompts/trivia_host_cli.prompt @@ -0,0 +1,9 @@ +--- +model: googleai/gemini-3-flash-preview +--- + +You are a trivia game host. Cheerfully greet the user when they first join and ask them for the theme of the trivia game. Suggest a few theme options, but they do not have to use them. + +When the user is ready for a question, call `present_questions` so the UI can show the question and multiple-choice answers. After the user answers, tell them if they were right or wrong. Be dramatic but brief. 
+ +[user joined the game] diff --git a/py/samples/tool-interrupts/src/approval_example.py b/py/samples/tool-interrupts/src/approval_example.py new file mode 100644 index 0000000000..220ed58264 --- /dev/null +++ b/py/samples/tool-interrupts/src/approval_example.py @@ -0,0 +1,188 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""**Bank transfer approval** — human-in-the-loop before a transfer tool finishes. + +The model calls ``request_transfer``; the CLI asks **approve (y)** or **decline (n)**. +**Approve** → ``tool.restart(...)`` / ``resume_restart`` so the tool **runs again** +with ``ToolRunContext.is_resumed``. **Decline** → ``respond_to_interrupt`` / +``resume_respond`` (no second tool run). + +Opening prompt, then one **canned user message** (no typing) so the model calls +``request_transfer``; you still answer **y/n** for approval. Run:: + + uv run src/approval_example.py + +For the trivia-only **respond** demo, see ``respond_example.py``. See README.md. 
+""" + +from pathlib import Path + +from pydantic import BaseModel, Field + +from genkit import ( + Genkit, + Interrupt, + ToolRunContext, + respond_to_interrupt, +) +from genkit.model import ModelResponse +from genkit.plugins.google_genai import GoogleAI # pyright: ignore[reportMissingImports] + +_PROMPTS_DIR = Path(__file__).resolve().parent.parent / 'prompts' + +_BAR = '=' * 52 +_RULE = '-' * 52 + +# Canned user line so the demo always triggers a transfer tool call without stdin. +USER_MESSAGE = 'Please wire $250.00 to Jane Doe (account ending in 4521) for April rent.' + + +def _print_intro() -> None: + print(f'\n{_BAR}') + print(' Bank transfer demo — outgoing wires need your approval in the CLI.') + print(_BAR) + print(' 1) The banker speaks first.') + print(' 2) A scripted user message asks for a wire; the model calls the transfer tool.') + print(' 3) When asked y/n: yes = approve (tool runs again); no = decline.') + print(f'{_BAR}\n') + + +def _print_scripted_user_turn() -> None: + print(_RULE) + print('Scripted user message (see USER_MESSAGE in source):') + print(_RULE) + + +def _print_waiting_opening() -> None: + print('Starting: banker opening (please wait)...\n') + + +def _print_model_turn(label: str, r: ModelResponse) -> None: + print(f'\n[{label}]') + if r.text: + print(r.text) + + +def _print_transfer_approval_prompt(summary: str) -> None: + print('\n' + _BAR) + print(' TRANSFER APPROVAL — y = approve (rerun tool) | n = decline') + print(_BAR) + if summary: + print(f' {summary}') + + +def _print_unexpected_tool(name: str) -> None: + print(f'Unexpected tool: {name!r}') + + +ai = Genkit( + plugins=[GoogleAI()], + model='googleai/gemini-3-flash-preview', + prompt_dir=_PROMPTS_DIR, +) + + +class TransferRequest(BaseModel): + """Wire transfer the user asked for; shown again before approval.""" + + to_account: str = Field(description='recipient name or masked account identifier') + amount_usd: str = Field(description='amount as a string, e.g. 
250.00') + memo: str = Field(default='', description='short reason (rent, invoice, gift, …)') + + +@ai.tool() +async def request_transfer(body: TransferRequest, ctx: ToolRunContext) -> dict: + """First run: interrupt for approval. After approval: return confirmation with metadata.""" + if not ctx.is_resumed(): + line = f'Wire ${body.amount_usd} to {body.to_account}' + if body.memo: + line = f'{line} — {body.memo}' + raise Interrupt({ + 'summary': line, + 'to_account': body.to_account, + 'amount_usd': body.amount_usd, + 'memo': body.memo, + 'needs_approval': True, + }) + return {'status': 'confirmed', 'resumed': ctx.resumed_metadata} + + +async def interactive_restart_cli() -> None: + """Opening prompt, scripted user line, then transfer approval via ``request_transfer``.""" + + _print_intro() + + _print_waiting_opening() + response = await ai.prompt('bank_transfer_host_cli')() + messages = response.messages + _print_model_turn('Banker (opening)', response) + + _print_scripted_user_turn() + user_said = USER_MESSAGE + print(f'\nYou: {user_said}\n') + + response = await ai.generate( + messages=messages, + prompt=user_said, + tools=[request_transfer], + ) + messages = response.messages + _print_model_turn('Banker', response) + + while response.interrupts: + interrupt = response.interrupts[0] + name = interrupt.tool_request.name + if name != request_transfer.name: + _print_unexpected_tool(name) + return + + meta = interrupt.metadata.get('interrupt') if interrupt.metadata else True + summary = meta.get('summary', '') if isinstance(meta, dict) else '' + _print_transfer_approval_prompt(summary) + ans = input('Approve transfer? 
[y/N]: ').strip().lower() + + if ans in ('y', 'yes'): + restart = request_transfer.restart( + interrupt=interrupt, + resumed_metadata={'via': 'cli', 'path': 'restart'}, + ) + response = await ai.generate( + messages=messages, + resume_restart=restart, + tools=[request_transfer], + ) + else: + decline_response = respond_to_interrupt( + {'status': 'declined'}, + interrupt=interrupt, + metadata={'source': 'cli', 'path': 'respond_decline'}, + ) + response = await ai.generate( + messages=messages, + resume_respond=decline_response, + tools=[request_transfer], + ) + messages = response.messages + _print_model_turn('Banker (after your decision)', response) + + +async def main() -> None: + await interactive_restart_cli() + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/samples/tool-interrupts/src/main.py b/py/samples/tool-interrupts/src/main.py deleted file mode 100755 index 4ce87963eb..0000000000 --- a/py/samples/tool-interrupts/src/main.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 - -"""Tool interrupts - Human-in-the-loop with ctx.interrupt() and tool_response(). 
See README.md.""" - -import os - -from pydantic import BaseModel, Field - -from genkit import ( - Genkit, - ToolRunContext, - tool_response, -) -from genkit.plugins.google_genai import GoogleAI -from genkit.plugins.google_genai.models import gemini - -ai = Genkit( - plugins=[GoogleAI()], - model=f'googleai/{gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW}', -) - - -class TriviaQuestions(BaseModel): - """Trivia questions.""" - - question: str = Field(description='the main question') - answers: list[str] = Field(description='list of multiple choice answers (typically 4), 1 correct 3 wrong') - - -@ai.tool() -async def present_questions(questions: TriviaQuestions, ctx: ToolRunContext) -> None: - """Presents questions to the user and responds with the selected answer.""" - ctx.interrupt(questions.model_dump()) - - -@ai.flow() -async def play_trivia(theme: str = 'Science') -> str: - """Plays a trivia game with the user.""" - response = await ai.generate( - prompt='You are a trivia game host. Cheerfully greet the user when they ' - + f'first join. The user has selected the theme: "{theme}". ' - + 'Call `present_questions` tool with questions and the tools will present ' - + 'the questions in a nice UI. The user will pick an answer and then you ' - + 'tell them if they were right or wrong. Be dramatic (but terse)! It is a ' - + 'show!\n\n[user joined the game]', - tools=['present_questions'], - ) - - # Check for interrupts and return the question to the user - if len(response.interrupts): - request = response.interrupts[0] - question_data = request.tool_request.input - if question_data: - # For a full interactive flow, you would typically: - # 1. Prompt the user for their answer here (e.g., using input()). - # 2. Call tool_response(request, user_answer) to resume the AI conversation. - # 3. Regenerate with the tool_response. 
- - # Prepend the greeting/text response if available - text_response = (response.text + '\n\n') if response.text else '' - question = question_data.get('question') - answers = question_data.get('answers') - return f'{text_response}INTERRUPTED: {question}\nAnswers: {answers}' - return 'INTERRUPTED: (No input data)' - - return response.text - - -async def main() -> None: - """Dev mode: return immediately; run_main keeps Dev UI alive. Standalone: interactive trivia CLI.""" - if os.environ.get('GENKIT_ENV') == 'dev': - return - try: - response = await ai.generate( - prompt='You are a trivia game host. Cheerfully greet the user when they ' - + 'first join and ask them for the theme of the trivia game. Suggest ' - + 'a few theme options, but they do not have to use them. When the user is ready, call ' - + '`present_questions` so the UI can show the question and answers. ' - + 'After the user answers, tell them if they were right or wrong. Be dramatic but brief.\n\n' - + '[user joined the game]', - ) - messages = response.messages - while True: - response = await ai.generate( - messages=messages, - prompt=input('Say: '), - tools=['present_questions'], - ) - messages = response.messages - if len(response.interrupts) > 0: - request = response.interrupts[0] - tr = tool_response(request, input('Your answer (number): ')) - response = await ai.generate( - messages=messages, - tool_responses=[tr], - tools=['present_questions'], - ) - messages = response.messages - except Exception as error: - print(f'Set GEMINI_API_KEY to a valid value before running this sample directly.\n{error}') # noqa: T201 - - -if __name__ == '__main__': - ai.run_main(main()) diff --git a/py/samples/tool-interrupts/src/respond_example.py b/py/samples/tool-interrupts/src/respond_example.py new file mode 100755 index 0000000000..741929fd99 --- /dev/null +++ b/py/samples/tool-interrupts/src/respond_example.py @@ -0,0 +1,161 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tool interrupts — trivia via ``present_questions`` and ``respond_to_interrupt``. + +``present_questions`` raises ``Interrupt`` with the question payload → the user +picks an answer → ``respond_to_interrupt(pick, interrupt=…, metadata=…)`` → second +``generate`` with ``resume_respond``. + +For **bank-transfer-style** tool approval (restart path), run ``approval_example.py`` instead. + +Run: ``uv run src/respond_example.py``. See README.md. +""" + +from pathlib import Path + +from pydantic import BaseModel, Field + +from genkit import ( + Genkit, + Interrupt, + respond_to_interrupt, +) +from genkit.model import ModelResponse +from genkit.plugins.google_genai import GoogleAI # pyright: ignore[reportMissingImports] + +_PROMPTS_DIR = Path(__file__).resolve().parent.parent / 'prompts' + +ai = Genkit( + plugins=[GoogleAI()], + model='googleai/gemini-3-flash-preview', + prompt_dir=_PROMPTS_DIR, +) + + +class TriviaQuestions(BaseModel): + """Payload passed into ``present_questions`` when the model calls the tool.""" + + question: str = Field(description='the main question') + answers: list[str] = Field( + description='list of multiple choice answers (typically 4), 1 correct 3 wrong', + ) + + +@ai.tool() +async def present_questions(questions: TriviaQuestions) -> None: + """Presents questions to the user and responds with the selected answer.""" + raise Interrupt(questions.model_dump(mode='json')) + + +DEMO_TOOLS = [present_questions] + 
+ +async def interactive_trivia_cli() -> None: + """Run the CLI: opening turn, then chat with trivia interrupts (respond path).""" + + def show(label: str, r: ModelResponse) -> None: + """Print one model turn (what the host said).""" + print(f'\n[{label}]') + if r.text: + print(r.text) + + quit_words = frozenset({'q', 'quit', 'exit', 'bye'}) + bar = '=' * 52 + + print(f'\n{bar}') + print(' Tool interrupt demo — trivia (respond path)') + print(bar) + print(' 1) Host speaks first (you do not type yet).') + print(' 2) Then you chat one line at a time.') + print(' 3) When you see numbered answers, reply with a number.') + print(' Say quit / exit / q / bye anytime to stop.') + print(f'{bar}\n') + + print('Starting: host opening (please wait)...\n') + response = await ai.prompt('trivia_host_cli')() + messages = response.messages + show('Host (opening)', response) + + print('-' * 52) + print('Your turn — reply to the host above, or type quit to leave.') + print('-' * 52) + + while True: + user_said = input('\nYou: ').strip() + if user_said.lower() in quit_words: + print('Goodbye.') + return + if not user_said: + print('Empty line — type a message, or quit to exit.') + continue + + response = await ai.generate( + messages=messages, + prompt=user_said, + tools=DEMO_TOOLS, + ) + messages = response.messages + show('Host', response) + + while response.interrupts: + interrupt = response.interrupts[0] + name = interrupt.tool_request.name + + if name != present_questions.name: + print(f'Unexpected tool: {name!r}') + return + + payload = interrupt.tool_request.input + if payload is None: + print('Interrupt with no tool input.') + return + trivia = TriviaQuestions.model_validate(payload) + + n = len(trivia.answers) + print('\n' + bar) + print(' QUESTION — answer with a number') + print(bar) + print(trivia.question) + for i, ans in enumerate(trivia.answers, start=1): + print(f' {i}. 
{ans}') + print(f'Enter 1–{n}.') + + pick = input('Your choice (number): ').strip() + if pick.lower() in quit_words: + print('Goodbye.') + return + + interrupt_response = respond_to_interrupt( + pick, + interrupt=interrupt, + metadata={'source': 'cli', 'path': 'respond'}, + ) + response = await ai.generate( + messages=messages, + resume_respond=[interrupt_response], + tools=DEMO_TOOLS, + ) + messages = response.messages + show('Host (after your answer)', response) + + +async def main() -> None: + await interactive_trivia_cli() + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/uv.lock b/py/uv.lock index 97f85fe78a..1c012b2d5b 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -12,7 +12,6 @@ resolution-markers = [ [manifest] members = [ "context", - "dynamic-tools", "evaluators", "fastapi-bugbot", "flask-hello", @@ -1223,27 +1222,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/09/d09dfaa2110884284be6006b7586ea519f7391de58ed5428f2bf457bcd03/dotpromptz_handlebars-0.1.8-cp310-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f23498821610d443a67c860922aba00d20bdd80b8421bfef0ceff07b713f8198", size = 666257, upload-time = "2026-01-30T06:44:46.929Z" }, ] -[[package]] -name = "dynamic-tools" -version = "0.2.0" -source = { editable = "samples/dynamic-tools" } -dependencies = [ - { name = "genkit" }, - { name = "genkit-plugin-google-genai" }, - { name = "pydantic" }, - { name = "structlog" }, - { name = "uvloop" }, -] - -[package.metadata] -requires-dist = [ - { name = "genkit", editable = "packages/genkit" }, - { name = "genkit-plugin-google-genai", editable = "plugins/google-genai" }, - { name = "pydantic", specifier = ">=2.0.0" }, - { name = "structlog", specifier = ">=24.0.0" }, - { name = "uvloop", specifier = ">=0.21.0" }, -] - [[package]] name = "email-validator" version = "2.3.0"