Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ The Python samples are intentionally short and beginner-friendly. Use this table
| `middleware` | Middleware | Observe or modify model requests with `use=[...]` |
| `output-formats` | Structured Output | Text, enum, JSON object, array, and JSONL outputs |
| `prompts` | Prompt, Dotprompt | Use `.prompt` files, helpers, variants, and streaming |
| `tool-interrupts` | Tool | Pause tool execution for human approval |
| `tool-interrupts` | Tool | Trivia (`respond_example.py`) and bank approval (`approval_example.py`) for interrupt / resume |
| `tracing` | Tracing | Watch spans appear in the Dev UI as they start |
| `vertexai-imagen` | Vertex AI, Image Generation | Generate an image with Vertex AI Imagen |

Expand Down
2 changes: 1 addition & 1 deletion py/docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

::: genkit.PublicError

::: genkit.ToolInterruptError
::: genkit.Interrupt

::: genkit.Message

Expand Down
2 changes: 1 addition & 1 deletion py/docs/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Types exported from genkit, genkit.model, genkit.embedder, genkit.plugin_api, an

::: genkit.PublicError

::: genkit.ToolInterruptError
::: genkit.Interrupt

::: genkit.Message

Expand Down
14 changes: 9 additions & 5 deletions py/packages/genkit/src/genkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
ExecutablePrompt,
ModelStreamResponse,
PromptGenerateOptions,
ResumeOptions,
)
from genkit._ai._tools import ToolInterruptError, ToolRunContext, tool_response
from genkit._ai._tools import (
Interrupt,
Tool,
ToolRunContext,
respond_to_interrupt,
)
from genkit._core._action import Action, StreamResponse
from genkit._core._error import GenkitError, PublicError
from genkit._core._model import Document
Expand Down Expand Up @@ -93,7 +97,9 @@
# Errors
'GenkitError',
'PublicError',
'ToolInterruptError',
'Interrupt',
'Tool',
'respond_to_interrupt',
# Content types
'Constrained',
'CustomPart',
Expand Down Expand Up @@ -126,9 +132,7 @@
'ActionRunContext',
'ExecutablePrompt',
'PromptGenerateOptions',
'ResumeOptions',
'ToolRunContext',
'tool_response',
'ModelRequest',
'ModelResponse',
'ModelResponseChunk',
Expand Down
126 changes: 81 additions & 45 deletions py/packages/genkit/src/genkit/_ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import socket
import threading
import uuid
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Awaitable, Callable, Coroutine, Sequence
from pathlib import Path
from typing import Any, ParamSpec, TypeVar, cast, overload
from typing import Any, TypeVar, cast, overload

import anyio
import uvicorn
Expand Down Expand Up @@ -71,7 +71,7 @@
ResourceOptions,
define_resource,
)
from genkit._ai._tools import define_tool
from genkit._ai._tools import Tool, define_interrupt, define_tool
from genkit._core._action import Action, ActionKind, ActionRunContext
from genkit._core._background import (
BackgroundAction,
Expand Down Expand Up @@ -107,6 +107,8 @@
Part,
SpanMetadata,
ToolChoice,
ToolRequestPart,
ToolResponsePart,
)

from ._decorators import _FlowDecorator, _FlowDecoratorWithChunk
Expand All @@ -118,7 +120,7 @@
InputT = TypeVar('InputT')
OutputT = TypeVar('OutputT')
ChunkT = TypeVar('ChunkT')
P = ParamSpec('P')

R = TypeVar('R')
T = TypeVar('T')

Expand Down Expand Up @@ -260,16 +262,45 @@ def define_dynamic_action_provider(
metadata=metadata,
)

def tool(
self, name: str | None = None, description: str | None = None
) -> Callable[[Callable[P, T]], Callable[P, T]]:
def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable[..., Any]], Tool]:
"""Decorator to register a function as a tool."""

def wrapper(func: Callable[P, T]) -> Callable[P, T]:
def wrapper(func: Callable[..., Any]) -> Tool:
return define_tool(self.registry, func, name, description)

return wrapper

def define_interrupt(
self,
name: str,
*,
input_schema: type[BaseModel] | dict[str, object] | None = None,
description: str | None = None,
) -> Tool:
"""Register an interrupt tool that always pauses for user input.

Args:
name: Tool name
input_schema: Optional input schema (Pydantic model or JSON schema dict)
description: Tool description

Returns:
The registered interrupt tool

Example:
ask_user = ai.define_interrupt(
name='ask_user',
input_schema=Question,
description='Ask the user a question',
)
"""
return define_interrupt(
self.registry,
name,
description=description,
input_schema=input_schema,
)

def define_evaluator(
self,
*,
Expand Down Expand Up @@ -393,7 +424,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -421,7 +452,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -449,7 +480,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -477,7 +508,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand All @@ -503,7 +534,7 @@ def define_prompt(
max_turns: int | None = None,
return_tool_requests: bool | None = None,
metadata: dict[str, object] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
tool_choice: ToolChoice | None = None,
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
Expand Down Expand Up @@ -743,10 +774,12 @@ async def generate(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
tool_responses: list[Part] | None = None,
resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None,
resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None,
resume_metadata: dict[str, Any] | None = None,
config: dict[str, object] | ModelConfig | None = None,
max_turns: int | None = None,
context: dict[str, object] | None = None,
Expand All @@ -768,10 +801,12 @@ async def generate(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
tool_responses: list[Part] | None = None,
resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None,
resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None,
resume_metadata: dict[str, Any] | None = None,
config: dict[str, object] | ModelConfig | None = None,
max_turns: int | None = None,
context: dict[str, object] | None = None,
Expand All @@ -791,10 +826,12 @@ async def generate(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
tool_responses: list[Part] | None = None,
resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None,
resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None,
resume_metadata: dict[str, Any] | None = None,
config: dict[str, object] | ModelConfig | None = None,
max_turns: int | None = None,
context: dict[str, object] | None = None,
Expand All @@ -806,7 +843,12 @@ async def generate(
use: list[ModelMiddleware] | None = None,
docs: list[Document] | None = None,
) -> ModelResponse[Any]:
"""Generate text or structured data using a language model."""
"""Generate text or structured data using a language model.

``tools`` is typed as ``Sequence`` rather than ``list`` because ``Sequence``
is covariant: ``list[Tool]`` or ``list[str]`` are both assignable to
``Sequence[str | Tool]``, but not to ``list[str | Tool]``.
"""
return await generate_action(
self.registry,
await to_generate_action_options(
Expand All @@ -819,7 +861,9 @@ async def generate(
tools=tools,
return_tool_requests=return_tool_requests,
tool_choice=tool_choice,
tool_responses=tool_responses,
resume_respond=resume_respond,
resume_restart=resume_restart,
resume_metadata=resume_metadata,
config=config,
max_turns=max_turns,
output_format=output_format,
Expand All @@ -843,9 +887,12 @@ def generate_stream(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None,
resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None,
resume_metadata: dict[str, Any] | None = None,
config: dict[str, object] | ModelConfig | None = None,
max_turns: int | None = None,
context: dict[str, object] | None = None,
Expand All @@ -868,9 +915,12 @@ def generate_stream(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None,
resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None,
resume_metadata: dict[str, Any] | None = None,
config: dict[str, object] | ModelConfig | None = None,
max_turns: int | None = None,
context: dict[str, object] | None = None,
Expand All @@ -891,9 +941,12 @@ def generate_stream(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None,
resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None,
resume_metadata: dict[str, Any] | None = None,
config: dict[str, object] | ModelConfig | None = None,
max_turns: int | None = None,
context: dict[str, object] | None = None,
Expand Down Expand Up @@ -922,6 +975,9 @@ async def _run_generate() -> ModelResponse[Any]:
tools=tools,
return_tool_requests=return_tool_requests,
tool_choice=tool_choice,
resume_respond=resume_respond,
resume_restart=resume_restart,
resume_metadata=resume_metadata,
config=config,
max_turns=max_turns,
output_format=output_format,
Expand Down Expand Up @@ -1055,26 +1111,6 @@ def current_context() -> dict[str, Any] | None:
"""Get the current execution context, or None if not in an action."""
return ActionRunContext._current_context() # pyright: ignore[reportPrivateUsage]

def dynamic_tool(
self,
*,
name: str,
fn: Callable[..., object],
description: str | None = None,
metadata: dict[str, object] | None = None,
) -> Action:
"""Create an unregistered tool action for passing directly to generate()."""
tool_meta: dict[str, object] = metadata.copy() if metadata else {}
tool_meta['type'] = 'tool'
tool_meta['dynamic'] = True
return Action(
kind=ActionKind.TOOL,
name=name,
fn=fn, # type: ignore[arg-type] # dynamic tools may be sync
description=description,
metadata=tool_meta,
)

async def flush_tracing(self) -> None:
"""Flush all pending trace spans to exporters."""
provider = trace_api.get_tracer_provider()
Expand Down Expand Up @@ -1132,7 +1168,7 @@ async def generate_operation(
prompt: str | list[Part] | None = None,
system: str | list[Part] | None = None,
messages: list[Message] | None = None,
tools: list[str] | None = None,
tools: Sequence[str | Tool] | None = None,
return_tool_requests: bool | None = None,
tool_choice: ToolChoice | None = None,
config: dict[str, object] | ModelConfig | None = None,
Expand Down
Loading
Loading