diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index a1902bb150..de864989e0 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -50,11 +50,11 @@ Message, ModelConfig, ModelFn, - ModelMiddleware, ModelResponse, ModelResponseChunk, define_model, ) +from genkit._core._middleware._base import BaseMiddleware from genkit._ai._prompt import ( ExecutablePrompt, ModelStreamResponse, @@ -395,7 +395,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: list[str] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, input_schema: type[InputT], output_schema: type[OutputT], @@ -423,7 +423,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: list[str] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, input_schema: type[InputT], output_schema: dict[str, object] | str | None = None, @@ -451,7 +451,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: list[str] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, input_schema: dict[str, object] | str | None = None, output_schema: type[OutputT], @@ -479,7 +479,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: list[str] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, input_schema: dict[str, object] | str | None = None, output_schema: dict[str, object] | str | None = None, @@ -505,7 +505,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: list[str] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, input_schema: type | dict[str, object] | str | None = None, output_schema: type | dict[str, object] | str | None = None, @@ -753,7 +753,7 @@ async def generate( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[OutputT]: ... @@ -778,7 +778,7 @@ async def generate( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[Any]: ... 
@@ -801,7 +801,7 @@ async def generate( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[Any]: """Generate text or structured data using a language model.""" @@ -852,7 +852,7 @@ def generate_stream( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[OutputT]: ... @@ -877,7 +877,7 @@ def generate_stream( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[Any]: ... @@ -900,7 +900,7 @@ def generate_stream( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[Any]: @@ -1141,7 +1141,7 @@ async def generate_operation( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, ) -> Operation: """Generate content using a long-running model, returning an Operation to poll.""" diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 8a7e753083..dd8b101a16 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -17,23 +17,27 @@ """Generate action.""" import copy -import inspect import re -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import Any, cast from pydantic import BaseModel from genkit._ai._formats._types import FormatDef, Formatter from genkit._ai._messages import inject_instructions -from genkit._ai._middleware import augment_with_context from genkit._ai._model import ( Message, - ModelMiddleware, ModelRequest, ModelResponse, ModelResponseChunk, ) +from genkit._core._middleware._augment_with_context import augment_with_context +from genkit._core._middleware._base import ( + BaseMiddleware, + GenerateHookParams, + ModelHookParams, + ToolHookParams, +) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources from genkit._ai._tools import ToolInterruptError from genkit._core._action import Action, ActionKind, ActionRunContext @@ -57,6 +61,29 @@ logger = get_logger(__name__) +async def _chain_tool_middleware( + middleware: list[BaseMiddleware], + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[tuple[Part | None, Part | None]]], +) -> tuple[Part | None, Part | None]: + """Run the tool middleware chain and return (response_part, interrupt_part).""" + runner: Callable[[ToolHookParams], Awaitable[tuple[Part | None, Part | None]]] = next_fn + for mw in reversed(middleware): + _mw = mw + _inner = runner + + async def run_next( + p: ToolHookParams, 
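+            # Keyword-only defaults bind this iteration's middleware and inner
+            # runner at definition time (avoids Python's late-binding closures).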
+ *, + _m: BaseMiddleware = _mw, + _i: Callable[[ToolHookParams], Awaitable[tuple[Part | None, Part | None]]] = _inner, + ) -> tuple[Part | None, Part | None]: + return await _m.wrap_tool(p, _i) + + runner = run_next + return await runner(params) + + # Matches data URIs: everything up to the first comma is the media-type + # parameters (e.g. "data:audio/L16;codec=pcm;rate=24000;base64,"). _DATA_URI_RE = re.compile(r'data:[^,]{0,200},(?=.{100})', re.ASCII) @@ -109,7 +136,7 @@ async def generate_action( on_chunk: Callable[[ModelResponseChunk], None] | None = None, message_index: int = 0, current_turn: int = 0, - middleware: list[ModelMiddleware] | None = None, + middleware: list[BaseMiddleware] | None = None, context: dict[str, Any] | None = None, ) -> ModelResponse: """Execute a generation request with tool calling and middleware support.""" @@ -198,55 +225,64 @@ def wrapper(chunk: ModelResponseChunk) -> None: if raw_request.docs and not supports_context: middleware.append(augment_with_context()) - async def dispatch( - index: int, + normalized_mw: list[BaseMiddleware] = list(middleware) + + async def dispatch_generate( + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + """Chain wrap_generate middleware and call next_fn.""" + runner: Callable[[GenerateHookParams], Awaitable[ModelResponse]] = next_fn + for mw in reversed(normalized_mw): + _mw = mw + _inner = runner + + async def run_next( + p: GenerateHookParams, + *, + _m: BaseMiddleware = _mw, + _i: Callable[[GenerateHookParams], Awaitable[ModelResponse]] = _inner, + ) -> ModelResponse: + return await _m.wrap_generate(p, _i) + + runner = run_next + return await runner(params) + + async def dispatch_model( req: ModelRequest, - ctx: ActionRunContext, chunk_callback: Callable[[ModelResponseChunk], None] | None, ) -> ModelResponse: - """Dispatch request through middleware chain to the model.""" - if not middleware or index == len(middleware): - # End of the chain, call the original model action + async def run_model(params: ModelHookParams) -> ModelResponse: return ( await model.run( - input=req, - context=ctx.context, - on_chunk=cast(Callable[[object], None], chunk_callback) if chunk_callback else None, + input=params.request, + context=params.context, + on_chunk=cast(Callable[[object], None], params.on_chunk) if params.on_chunk else None, ) ).response - current_middleware = middleware[index] - n_params = len(inspect.signature(current_middleware).parameters) + runner: Callable[[ModelHookParams], Awaitable[ModelResponse]] = run_model + for mw in reversed(normalized_mw): + _mw = mw + _inner = runner - if n_params == 4: - # Streaming middleware: (req, ctx, on_chunk, next) -> response - async def next_fn_streaming( - modified_req: ModelRequest | None = None, - modified_ctx: ActionRunContext | None = None, - modified_on_chunk: Callable[[ModelResponseChunk], None] | None = None, + async def run_next( + params: ModelHookParams, + *, + _mw: BaseMiddleware = _mw, + _inner: Callable[[ModelHookParams], Awaitable[ModelResponse]] = _inner, ) -> ModelResponse: - return await dispatch( - index + 1, - modified_req if modified_req else req, - modified_ctx if modified_ctx else ctx, - modified_on_chunk if modified_on_chunk is not None else chunk_callback, - ) + return await _mw.wrap_model(params, _inner) - return await current_middleware(req, ctx, chunk_callback, next_fn_streaming) - else: - # Simple middleware: (req, ctx, next) -> response - async def next_fn_simple( - modified_req: 
ModelRequest | None = None, - modified_ctx: ActionRunContext | None = None, - ) -> ModelResponse: - return await dispatch( - index + 1, - modified_req if modified_req else req, - modified_ctx if modified_ctx else ctx, - chunk_callback, - ) + runner = cast(Callable[[ModelHookParams], Awaitable[ModelResponse]], run_next) - return await current_middleware(req, ctx, next_fn_simple) + return await runner( + ModelHookParams( + request=req, + on_chunk=chunk_callback, + context=context or {}, + ) + ) # if resolving the 'resume' option above generated a tool message, stream it. if resumed_tool_message and on_chunk: @@ -257,101 +293,112 @@ async def next_fn_simple( ) ) - model_response = await dispatch( - 0, - request, - ActionRunContext(context=context), - wrap_chunks() if on_chunk else None, - ) - - def message_parser(msg: Message) -> Any: # noqa: ANN401 - if formatter is None: - return None - return formatter.parse_message(msg) + async def run_one_iteration(_params: GenerateHookParams) -> ModelResponse: + """Execute one turn of the generate loop (model call + optional tool resolution).""" + model_response = await dispatch_model( + request, + wrap_chunks() if on_chunk else None, + ) - # Extract schema_type for runtime Pydantic validation - schema_type = raw_request.output.schema_type if raw_request.output else None + def message_parser(msg: Message) -> Any: # noqa: ANN401 + if formatter is None: + return None + return formatter.parse_message(msg) + + # Extract schema_type for runtime Pydantic validation + schema_type = raw_request.output.schema_type if raw_request.output else None + + # Plugin returns ModelResponse directly. Framework sets request and + # any output format context (message_parser, schema_type) as private attrs. + response = model_response + response.request = request + if formatter: + response._message_parser = message_parser + if schema_type: + response._schema_type = schema_type + + logger.debug( + 'generate response', + response=_redact_data_uris(response.model_dump()), + ) - # Plugin returns ModelResponse directly. Framework sets request and - # any output format context (message_parser, schema_type) as private attrs. 
- response = model_response - response.request = request - if formatter: - response._message_parser = message_parser - if schema_type: - response._schema_type = schema_type + response.assert_valid() + generated_msg = response.message - logger.debug('generate response', response=_redact_data_uris(response.model_dump())) + if generated_msg is None: + # No message in response, return as-is + return response - response.assert_valid() - generated_msg = response.message + tool_requests = [x for x in generated_msg.content if x.root.tool_request] - if generated_msg is None: - # No message in response, return as-is - return response + if raw_request.return_tool_requests or len(tool_requests) == 0: + if len(tool_requests) == 0: + response.assert_valid_schema() + return response - tool_requests = [x for x in generated_msg.content if x.root.tool_request] + max_iters = raw_request.max_turns if raw_request.max_turns else DEFAULT_MAX_TURNS - if raw_request.return_tool_requests or len(tool_requests) == 0: - if len(tool_requests) == 0: - response.assert_valid_schema() - return response + if current_turn + 1 > max_iters: + raise GenerationResponseError( + response=response, + message=f'Exceeded maximum tool call iterations ({max_iters})', + status='ABORTED', + details={'request': request}, + ) - max_iters = raw_request.max_turns if raw_request.max_turns else DEFAULT_MAX_TURNS + ( + revised_model_msg, + tool_msg, + transfer_preamble, + ) = await resolve_tool_requests(registry, raw_request, generated_msg, middleware=normalized_mw) + + # if an interrupt message is returned, stop the tool loop and return a + # response. + if revised_model_msg: + interrupted_resp = response.model_copy(deep=False) + interrupted_resp.finish_reason = FinishReason.INTERRUPTED + interrupted_resp.finish_message = 'One or more tool calls resulted in interrupts.' + interrupted_resp.message = Message(revised_model_msg) + return interrupted_resp + + # If the loop will continue, stream out the tool response message... + if on_chunk and tool_msg: + on_chunk( + make_chunk( + Role.TOOL, + ModelResponseChunk( + role=tool_msg.role, + content=tool_msg.content, + ), + ) + ) - if current_turn + 1 > max_iters: - raise GenerationResponseError( - response=response, - message=f'Exceeded maximum tool call iterations ({max_iters})', - status='ABORTED', - details={'request': request}, - ) + next_request = copy.copy(raw_request) + next_messages = copy.copy(raw_request.messages) + next_messages.append(generated_msg) + if tool_msg: + next_messages.append(tool_msg) + next_request.messages = next_messages + if transfer_preamble: + next_request = apply_transfer_preamble(next_request, transfer_preamble) - ( - revised_model_msg, - tool_msg, - transfer_preamble, - ) = await resolve_tool_requests(registry, raw_request, generated_msg) - - # if an interrupt message is returned, stop the tool loop and return a - # response. - if revised_model_msg: - interrupted_resp = response.model_copy(deep=False) - interrupted_resp.finish_reason = FinishReason.INTERRUPTED - interrupted_resp.finish_message = 'One or more tool calls resulted in interrupts.' - interrupted_resp.message = Message(revised_model_msg) - return interrupted_resp - - # If the loop will continue, stream out the tool response message... 
- if on_chunk and tool_msg: - on_chunk( - make_chunk( - Role.TOOL, - ModelResponseChunk( - role=tool_msg.role, - content=tool_msg.content, - ), - ) + # then recursively call for another loop + return await generate_action( + registry, + raw_request=next_request, + middleware=normalized_mw, + current_turn=current_turn + 1, + message_index=message_index + 1, + on_chunk=on_chunk, + context=context, ) - next_request = copy.copy(raw_request) - next_messages = copy.copy(raw_request.messages) - next_messages.append(generated_msg) - if tool_msg: - next_messages.append(tool_msg) - next_request.messages = next_messages - if transfer_preamble: - next_request = apply_transfer_preamble(next_request, transfer_preamble) - - # then recursively call for another loop - return await generate_action( - registry, - raw_request=next_request, - # middleware: middleware, - current_turn=current_turn + 1, - message_index=message_index + 1, - on_chunk=on_chunk, + generate_params = GenerateHookParams( + options=raw_request, + request=request, + iteration=current_turn, ) + return await dispatch_generate(generate_params, run_one_iteration) def apply_format( @@ -595,7 +642,11 @@ def to_tool_definition(tool: Action) -> ToolDefinition: async def resolve_tool_requests( - registry: Registry, request: GenerateActionOptions, message: Message + registry: Registry, + request: GenerateActionOptions, + message: Message, + *, + middleware: list[BaseMiddleware] | None = None, ) -> tuple[Message | None, Message | None, GenerateActionOptions | None]: """Execute tool requests in a message, returning responses or interrupt info.""" # TODO(#4342): prompt transfer @@ -605,6 +656,7 @@ async def resolve_tool_requests( tool_dict[tool_name] = await resolve_tool(registry, tool_name) revised_model_message = message.model_copy(deep=True) + mw_list = middleware or [] has_interrupts = False response_parts: list[Part] = [] @@ -621,7 +673,16 @@ async def resolve_tool_requests( if tool_request.name not in tool_dict: raise RuntimeError(f'failed {tool_request.name} not found') tool = tool_dict[tool_request.name] - tool_response_part, interrupt_part = await _resolve_tool_request(tool, tool_req_root) + + if mw_list: + params = ToolHookParams(tool_request_part=tool_req_root, tool=tool) + + async def next_fn(p: ToolHookParams) -> tuple[Part | None, Part | None]: + return await _resolve_tool_request(p.tool, p.tool_request_part) + + tool_response_part, interrupt_part = await _chain_tool_middleware(mw_list, params, next_fn) + else: + tool_response_part, interrupt_part = await _resolve_tool_request(tool, tool_req_root) if tool_response_part: # Extract the ToolResponsePart from the returned Part for _to_pending_response diff --git a/py/packages/genkit/src/genkit/_ai/_middleware.py b/py/packages/genkit/src/genkit/_ai/_middleware.py deleted file mode 100644 index 54931710ea..0000000000 --- a/py/packages/genkit/src/genkit/_ai/_middleware.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# SPDX-License-Identifier: Apache-2.0 - -"""Middleware for the Genkit framework.""" - -from collections.abc import Awaitable, Callable - -from genkit._ai._model import ( - Message, - ModelMiddleware, - ModelRequest, - ModelResponse, - text_from_content, -) -from genkit._core._action import ActionRunContext -from genkit._core._model import Document -from genkit._core._typing import ( - Part, - TextPart, -) - -CONTEXT_PREFACE = '\n\nUse the following information to complete your task:\n\n' - - -def context_item_template(d: Document, index: int) -> str: - """Render a document as a citation line for context injection.""" - out = '- ' - ref = (d.metadata and (d.metadata.get('ref') or d.metadata.get('id'))) or index - out += f'[{ref}]: ' - out += text_from_content(d.content) + '\n' - return out - - -def augment_with_context() -> ModelMiddleware: - """Middleware that injects document context into the last user message.""" - - async def middleware( - req: ModelRequest, - ctx: ActionRunContext, - next_middleware: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - if not req.docs: - return await next_middleware(req, ctx) - - user_message = last_user_message(req.messages) - if not user_message: - return await next_middleware(req, ctx) - - context_part_index = -1 - for i, part in enumerate(user_message.content): - part_metadata = part.root.metadata - if isinstance(part_metadata, dict) and part_metadata.get('purpose') == 'context': - context_part_index = i - break - - context_part = user_message.content[context_part_index] if context_part_index >= 0 else None - - if context_part: - metadata = context_part.root.metadata - if not (isinstance(metadata, dict) and metadata.get('pending')): - return await next_middleware(req, ctx) - - out = CONTEXT_PREFACE - for i, doc_data in enumerate(req.docs): - doc = Document(content=doc_data.content, metadata=doc_data.metadata) - out += context_item_template(doc, i) - out += '\n' - - text_part = Part(root=TextPart(text=out, metadata={'purpose': 'context'})) - if context_part_index >= 0: - user_message.content[context_part_index] = text_part - else: - if not user_message.content: - user_message.content = [] - user_message.content.append(text_part) - - return await next_middleware(req, ctx) - - return middleware - - -def last_user_message(messages: list[Message]) -> Message | None: - """Find the last user message in a list.""" - for i in range(len(messages) - 1, -1, -1): - if messages[i].role == 'user': - return messages[i] - return None diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index 2d7767a0ee..7e759aaf60 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -39,11 +39,11 @@ ) from genkit._ai._model import ( Message, - ModelMiddleware, ModelRequest, ModelResponse, ModelResponseChunk, ) +from genkit._core._middleware._base import BaseMiddleware from genkit._core._action import Action, ActionKind, ActionRunContext, StreamingCallback, create_action_key from genkit._core._channel import Channel from genkit._core._error import GenkitError @@ -107,7 +107,7 @@ class PromptGenerateOptions(TypedDict, total=False): return_tool_requests: bool | None max_turns: int | None on_chunk: ModelStreamingCallback | None - use: list[ModelMiddleware] | None + use: list[BaseMiddleware] | None context: dict[str, Any] | None step_name: str | None metadata: dict[str, Any] | None @@ -186,7 +186,7 @@ class PromptConfig(BaseModel): metadata: dict[str, Any] | 
None = None tools: list[str] | None = None tool_choice: ToolChoice | None = None - use: list[ModelMiddleware] | None = None + use: list[BaseMiddleware] | None = None docs: list[Document] | None = None tool_responses: list[Part] | None = None resources: list[str] | None = None @@ -216,7 +216,7 @@ def __init__( metadata: dict[str, Any] | None = None, tools: list[str] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware] | None = None, docs: list[Document] | None = None, resources: list[str] | None = None, name: str | None = None, diff --git a/py/packages/genkit/src/genkit/_core/_middleware/__init__.py b/py/packages/genkit/src/genkit/_core/_middleware/__init__.py new file mode 100644 index 0000000000..2b0c8a49ae --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Internal middleware package. No exports - import from submodules.""" diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_augment_with_context.py b/py/packages/genkit/src/genkit/_core/_middleware/_augment_with_context.py new file mode 100644 index 0000000000..82baea68a0 --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/_augment_with_context.py @@ -0,0 +1,96 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""augment_with_context middleware.""" + +from collections.abc import Awaitable, Callable + +from genkit._core._model import Document, ModelResponse +from genkit._core._typing import Part, TextPart + +from ._base import BaseMiddleware, ModelHookParams +from ._utils import _CONTEXT_PREFACE, _context_item_template, _last_user_message + + +def augment_with_context( + preface: str | None = _CONTEXT_PREFACE, + item_template: Callable[[Document, int], str] | None = None, + citation_key: str | None = None, +) -> BaseMiddleware: + """Middleware that injects document context into the last user message.""" + return _AugmentWithContextMiddleware( + preface=preface, + item_template=item_template or _context_item_template, + citation_key=citation_key, + ) + + +class _AugmentWithContextMiddleware(BaseMiddleware): + def __init__( + self, + preface: str | None = _CONTEXT_PREFACE, + item_template: Callable[[Document, int], str] | None = None, + citation_key: str | None = None, + ) -> None: + self._preface = preface + self._item_template = item_template or _context_item_template + self._citation_key = citation_key + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + req = params.request + if not req.docs: + return await next_fn(params) + + user_message = _last_user_message(req.messages) # type: ignore[arg-type] + if not user_message: + return await next_fn(params) + + context_part_index = -1 + for i, part in enumerate(user_message.content): + part_metadata = part.root.metadata if hasattr(part.root, 'metadata') else None + if isinstance(part_metadata, dict) and part_metadata.get('purpose') == 'context': + context_part_index = i + break + + context_part = user_message.content[context_part_index] if context_part_index >= 0 else None + + if context_part: + metadata = context_part.root.metadata if hasattr(context_part.root, 'metadata') else None + if not (isinstance(metadata, dict) and metadata.get('pending')): + return await next_fn(params) + + out = self._preface or '' + for i, doc_data in enumerate(req.docs): + doc = Document(content=doc_data.content, metadata=doc_data.metadata) + if self._citation_key and doc.metadata: + doc.metadata['ref'] = doc.metadata.get(self._citation_key, i) + out += self._item_template(doc, i) + out += '\n' + + text_part = Part(root=TextPart(text=out, metadata={'purpose': 'context'})) + + if context_part_index >= 0: + user_message.content[context_part_index] = text_part + else: + if not user_message.content: + user_message.content = [] + user_message.content.append(text_part) + + return await next_fn(params) diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_base.py b/py/packages/genkit/src/genkit/_core/_middleware/_base.py new file mode 100644 index 0000000000..795d0b5ff9 --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/_base.py @@ -0,0 +1,113 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Base middleware protocol, params, and default implementation.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import ClassVar, Protocol + +from pydantic import BaseModel, ConfigDict, Field + +from genkit._core._action import Action +from genkit._core._model import ModelRequest, ModelResponse, ModelResponseChunk +from genkit._core._typing import GenerateActionOptions, Part, ToolRequestPart + + +class Middleware(Protocol): + """Middleware with hooks for Generate loop, Model call, and Tool execution. + + Use [BaseMiddleware] as a base to implement only the hooks you need. + """ + + def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> Awaitable[ModelResponse]: + """Wrap each iteration of the tool loop (model call + optional tool resolution).""" + ... + + def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> Awaitable[ModelResponse]: + """Wrap each model API call.""" + ... + + def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[tuple[Part | None, Part | None]]], + ) -> Awaitable[tuple[Part | None, Part | None]]: + """Wrap each tool execution.""" + ... + + +class GenerateHookParams(BaseModel): + """Params for the wrap_generate hook.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + options: GenerateActionOptions + request: ModelRequest + iteration: int + + +class ModelHookParams(BaseModel): + """Params for the wrap_model hook.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + request: ModelRequest + on_chunk: Callable[[ModelResponseChunk], None] | None = None + context: dict[str, object] = Field(default_factory=dict) + + +class ToolHookParams(BaseModel): + """Params for the wrap_tool hook.""" + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + tool_request_part: ToolRequestPart + tool: Action + + +class BaseMiddleware: + """Base middleware with pass-through defaults. Override only the hooks you need.""" + + def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> Awaitable[ModelResponse]: + return next_fn(params) + + def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> Awaitable[ModelResponse]: + return next_fn(params) + + def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[tuple[Part | None, Part | None]]], + ) -> Awaitable[tuple[Part | None, Part | None]]: + return next_fn(params) diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_download_request_media.py b/py/packages/genkit/src/genkit/_core/_middleware/_download_request_media.py new file mode 100644 index 0000000000..565aeb60ff --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/_download_request_media.py @@ -0,0 +1,117 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""download_request_media middleware.""" + +import base64 +from collections.abc import Awaitable, Callable + +import httpx + +from genkit._ai._model import Message +from genkit._core._error import GenkitError +from genkit._core._model import ModelResponse +from genkit._core._typing import Media, MediaPart, Part + +from ._base import BaseMiddleware, ModelHookParams +from ._utils import _is_safe_url + + +def download_request_media( + max_bytes: int | None = None, + filter_fn: Callable[[Part], bool] | None = None, +) -> BaseMiddleware: + """Middleware that downloads HTTP media URLs and converts to base64 data URIs.""" + return _DownloadRequestMediaMiddleware(max_bytes=max_bytes, filter_fn=filter_fn) + + +class _DownloadRequestMediaMiddleware(BaseMiddleware): + def __init__( + self, + max_bytes: int | None = None, + filter_fn: Callable[[Part], bool] | None = None, + ) -> None: + self._max_bytes = max_bytes + self._filter_fn = filter_fn + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + req = params.request + async with httpx.AsyncClient() as client: + new_messages: list[Message] = [] + + for msg in req.messages: + new_content: list[Part] = [] + content_changed = False + + for part in msg.content: + if isinstance(part.root, MediaPart) and part.root.media.url.startswith('http'): + if not _is_safe_url(part.root.media.url): + raise GenkitError( + status='INVALID_ARGUMENT', + message=f"Media URL is not allowed (SSRF protection): '{part.root.media.url}'", + ) + if self._filter_fn is not None and not self._filter_fn(part): + new_content.append(part) + continue + + content_changed = True + try: + response = await client.get(part.root.media.url) + response.raise_for_status() + + content = response.content + if self._max_bytes is not None and len(content) > self._max_bytes: + content = content[: self._max_bytes] + + content_type = part.root.media.content_type or response.headers.get( + 'content-type', 'application/octet-stream' + ) + + b64_data = base64.b64encode(content).decode('utf-8') + data_uri = f'data:{content_type};base64,{b64_data}' + + new_part = Part( + root=MediaPart( + media=Media(url=data_uri, content_type=content_type), + ) + ) + new_content.append(new_part) + + except httpx.HTTPError as e: + raise GenkitError( + status='INVALID_ARGUMENT', + message=f"Failed to download media from '{part.root.media.url}': {e}", + ) from e + else: + new_content.append(part) + + if content_changed: + new_messages.append(Message(role=msg.role, content=new_content, metadata=msg.metadata)) + else: + new_messages.append(msg) + + new_req = req.model_copy(update={'messages': new_messages}) + return await next_fn( + ModelHookParams( + request=new_req, + on_chunk=params.on_chunk, + context=params.context, + ) + ) diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_fallback.py b/py/packages/genkit/src/genkit/_core/_middleware/_fallback.py new file mode 100644 index 0000000000..5ddbb2e672 --- /dev/null +++ 
b/py/packages/genkit/src/genkit/_core/_middleware/_fallback.py @@ -0,0 +1,121 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""fallback middleware.""" + +from collections.abc import Awaitable, Callable +from typing import Protocol + +from genkit._core._error import GenkitError, StatusName +from genkit._core._model import ModelResponse +from genkit._core._registry import Registry + +from ._base import BaseMiddleware, ModelHookParams +from ._utils import _DEFAULT_FALLBACK_STATUSES + + +class HasRegistry(Protocol): + """Protocol for objects that have a registry (e.g. Genkit instance).""" + + registry: Registry + + +def fallback( + ai: HasRegistry, + models: list[str], + statuses: list[StatusName] | None = None, + on_error: Callable[[GenkitError], None] | None = None, +) -> BaseMiddleware: + """Middleware that falls back to alternative models on failure. + + Args: + ai: Object with a registry (e.g. Genkit instance) for resolving fallback models. + models: Ordered list of fallback model names to try. + statuses: List of status codes that trigger fallback (default: UNAVAILABLE, + DEADLINE_EXCEEDED, RESOURCE_EXHAUSTED, ABORTED, INTERNAL, NOT_FOUND, + UNIMPLEMENTED). + on_error: Optional callback called when fallback is triggered. + + Returns: + Middleware that implements fallback logic. 
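+
+    Example (model names are illustrative):
+
+        response = await ai.generate(
+            model="gemini-pro",
+            prompt="Hello",
+            use=[fallback(ai, models=["gemini-flash"])],
+        )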
+ """ + return _fallback_for_registry(ai.registry, models, statuses, on_error) + + +def _fallback_for_registry( + registry: Registry, + models: list[str], + statuses: list[StatusName] | None = None, + on_error: Callable[[GenkitError], None] | None = None, +) -> BaseMiddleware: + """Internal: fallback middleware that takes a Registry (for testing).""" + return _FallbackMiddleware( + registry=registry, + models=models, + statuses=statuses or _DEFAULT_FALLBACK_STATUSES, + on_error=on_error, + ) + + +class _FallbackMiddleware(BaseMiddleware): + def __init__( + self, + registry: Registry, + models: list[str], + statuses: list[StatusName], + on_error: Callable[[GenkitError], None] | None = None, + ) -> None: + self._registry = registry + self._models = models + self._statuses = statuses + self._on_error = on_error + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + try: + return await next_fn(params) + except Exception as e: + if isinstance(e, GenkitError) and e.status in self._statuses: + if self._on_error: + self._on_error(e) + + last_error: Exception = e + for model_name in self._models: + try: + model = await self._registry.resolve_model(model_name) + if model is None: + raise GenkitError( + status='NOT_FOUND', + message=f"Fallback model '{model_name}' not found.", + ) + result = await model.run( + input=params.request, + context=params.context, + on_chunk=params.on_chunk, + ) + return result.response + except Exception as e2: + last_error = e2 + if isinstance(e2, GenkitError) and e2.status in self._statuses: + if self._on_error: + self._on_error(e2) + continue + raise + raise last_error from None + raise diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_retry.py b/py/packages/genkit/src/genkit/_core/_middleware/_retry.py new file mode 100644 index 0000000000..e56512a08e --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/_retry.py @@ -0,0 +1,120 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""retry middleware.""" + +import asyncio +import random +from collections.abc import Awaitable, Callable + +from genkit._core._error import GenkitError, StatusName +from genkit._core._model import ModelResponse + +from ._base import BaseMiddleware, ModelHookParams +from ._utils import _DEFAULT_RETRY_STATUSES + + +def retry( + max_retries: int = 3, + statuses: list[StatusName] | None = None, + initial_delay_ms: int = 1000, + max_delay_ms: int = 60000, + backoff_factor: float = 2.0, + jitter: bool = True, + on_error: Callable[[Exception, int], None] | None = None, +) -> BaseMiddleware: + """Middleware that retries failed requests with exponential backoff. + + Args: + max_retries: Maximum number of retry attempts (default: 3). + statuses: List of status codes that trigger retry (default: UNAVAILABLE, + DEADLINE_EXCEEDED, RESOURCE_EXHAUSTED, ABORTED, INTERNAL). 
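+            Errors with other statuses (and non-Genkit errors) are raised
+            immediately without retrying.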
+ initial_delay_ms: Initial delay between retries in milliseconds (default: 1000). + max_delay_ms: Maximum delay between retries in milliseconds (default: 60000). + backoff_factor: Multiplier for delay after each retry (default: 2.0). + jitter: Whether to add random jitter to delays (default: True). + on_error: Optional callback called on each retry attempt with (error, attempt). + + Returns: + Middleware that implements retry logic. + """ + return _RetryMiddleware( + max_retries=max_retries, + statuses=statuses or _DEFAULT_RETRY_STATUSES, + initial_delay_ms=initial_delay_ms, + max_delay_ms=max_delay_ms, + backoff_factor=backoff_factor, + jitter=jitter, + on_error=on_error, + ) + + +class _RetryMiddleware(BaseMiddleware): + def __init__( + self, + max_retries: int = 3, + statuses: list[StatusName] | None = None, + initial_delay_ms: int = 1000, + max_delay_ms: int = 60000, + backoff_factor: float = 2.0, + jitter: bool = True, + on_error: Callable[[Exception, int], None] | None = None, + ) -> None: + self._max_retries = max_retries + self._statuses = statuses or _DEFAULT_RETRY_STATUSES + self._initial_delay_ms = initial_delay_ms + self._max_delay_ms = max_delay_ms + self._backoff_factor = backoff_factor + self._jitter = jitter + self._on_error = on_error + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + last_error: Exception | None = None + current_delay_ms: float = float(self._initial_delay_ms) + + for attempt in range(int(self._max_retries) + 1): + try: + return await next_fn(params) + except Exception as e: + last_error = e + + if attempt < self._max_retries: + should_retry = isinstance(e, GenkitError) and e.status in self._statuses + + if should_retry: + if self._on_error: + self._on_error(e, attempt + 1) + + delay = current_delay_ms + if self._jitter: + delay = delay + random.random() * (2**attempt) * 1000 + + await asyncio.sleep(delay / 1000.0) + current_delay_ms = min( + current_delay_ms * self._backoff_factor, + float(self._max_delay_ms), + ) + continue + + raise + + if last_error: + raise last_error + raise RuntimeError('Retry loop completed without result') diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_simulate_system_prompt.py b/py/packages/genkit/src/genkit/_core/_middleware/_simulate_system_prompt.py new file mode 100644 index 0000000000..08520c57ac --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/_simulate_system_prompt.py @@ -0,0 +1,86 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""simulate_system_prompt middleware.""" + +from collections.abc import Awaitable, Callable + +from genkit._ai._model import Message +from genkit._core._model import ModelResponse +from genkit._core._typing import Part, TextPart + +from ._base import BaseMiddleware, ModelHookParams + + +def simulate_system_prompt( + preface: str = 'SYSTEM INSTRUCTIONS:\n', + acknowledgement: str = 'Understood.', +) -> BaseMiddleware: + r"""Middleware that simulates system prompt for models without native support. + + Converts system messages to user+model message pairs. + + Args: + preface: Text to prepend to the system content. + acknowledgement: Model's acknowledgement response. + + Returns: + Middleware that transforms system messages. + """ + return _SimulateSystemPromptMiddleware(preface=preface, acknowledgement=acknowledgement) + + +class _SimulateSystemPromptMiddleware(BaseMiddleware): + def __init__( + self, + preface: str = 'SYSTEM INSTRUCTIONS:\n', + acknowledgement: str = 'Understood.', + ) -> None: + self._preface = preface + self._acknowledgement = acknowledgement + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + req = params.request + new_messages: list[Message] = [] + system_found = False + + for msg in req.messages: + if msg.role == 'system' and not system_found: + user_content: list[Part] = [Part(root=TextPart(text=self._preface))] + user_content.extend(msg.content) + new_messages.append(Message(role='user', content=user_content)) + new_messages.append( + Message( + role='model', + content=[Part(root=TextPart(text=self._acknowledgement))], + ) + ) + system_found = True + else: + new_messages.append(msg) + + new_req = req.model_copy(update={'messages': new_messages}) + return await next_fn( + ModelHookParams( + request=new_req, + on_chunk=params.on_chunk, + context=params.context, + ) + ) diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_utils.py b/py/packages/genkit/src/genkit/_core/_middleware/_utils.py new file mode 100644 index 0000000000..10ea76aa4b --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/_utils.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Shared middleware utilities and constants.""" + +import ipaddress +from urllib.parse import urlparse + +from genkit._ai._model import Message, text_from_content +from genkit._core._error import StatusName +from genkit._core._model import Document + +_CONTEXT_PREFACE = '\n\nUse the following information to complete your task:\n\n' + +_DEFAULT_RETRY_STATUSES: list[StatusName] = [ + 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', + 'RESOURCE_EXHAUSTED', + 'ABORTED', + 'INTERNAL', +] + +_DEFAULT_FALLBACK_STATUSES: list[StatusName] = [ + 'UNAVAILABLE', + 'DEADLINE_EXCEEDED', + 'RESOURCE_EXHAUSTED', + 'ABORTED', + 'INTERNAL', + 'NOT_FOUND', + 'UNIMPLEMENTED', +] + +_SSRF_BLOCKED_HOSTNAMES: frozenset[str] = frozenset(('metadata.google.internal', 'metadata', '169.254.169.254')) + + +def _is_safe_url(url: str) -> bool: + """Check if URL is safe for download (blocks SSRF: private IPs, loopback, cloud metadata).""" + try: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + return False + host_lower = hostname.lower() + blocked = ('localhost', 'localhost.', 'ip6-localhost', 'ip6-loopback') + if host_lower in blocked or host_lower in _SSRF_BLOCKED_HOSTNAMES: + return False + try: + addr = ipaddress.ip_address(hostname) + except ValueError: + return True # Hostname (e.g. example.com); caller can use filter_fn to restrict + return not (addr.is_private or addr.is_loopback or addr.is_link_local) + except Exception: + return False + + +def _last_user_message(messages: list[Message]) -> Message | None: + """Find the last user message in a list.""" + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == 'user': + return messages[i] + return None + + +def _context_item_template(d: Document, index: int) -> str: + """Render a document as a citation line for context injection.""" + out = '- ' + ref = (d.metadata and (d.metadata.get('ref') or d.metadata.get('id'))) or index + out += f'[{ref}]: ' + out += text_from_content(d.content) + '\n' + return out diff --git a/py/packages/genkit/src/genkit/_core/_middleware/_validate_support.py b/py/packages/genkit/src/genkit/_core/_middleware/_validate_support.py new file mode 100644 index 0000000000..3fb58762d5 --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware/_validate_support.py @@ -0,0 +1,98 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""validate_support middleware.""" + +from collections.abc import Awaitable, Callable + +from genkit._core._error import GenkitError +from genkit._core._model import ModelResponse +from genkit._core._typing import MediaPart, Supports + +from ._base import BaseMiddleware, ModelHookParams + + +def validate_support( + name: str, + supports: Supports | None = None, +) -> BaseMiddleware: + """Middleware that validates request against model capabilities. + + Args: + name: The model name (for error messages). + supports: The model's capability flags. 
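+            If None, validation is skipped and the request passes through
+            unchanged.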
+ + Returns: + Middleware that validates requests. + + Raises: + GenkitError: With INVALID_ARGUMENT status if validation fails. + """ + return _ValidateSupportMiddleware(name=name, supports=supports) + + +class _ValidateSupportMiddleware(BaseMiddleware): + def __init__(self, name: str, supports: Supports | None = None) -> None: + self._name = name + self._supports = supports + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + req = params.request + if self._supports is None: + return await next_fn(params) + + if self._supports.media is False: + for msg in req.messages: + for part in msg.content: + if isinstance(part.root, MediaPart) and part.root.media is not None: + raise GenkitError( + status='INVALID_ARGUMENT', + message=f"Model '{self._name}' does not support media, but media was provided.", + ) + + if self._supports.tools is False and req.tools: + raise GenkitError( + status='INVALID_ARGUMENT', + message=f"Model '{self._name}' does not support tool use, but tools were provided.", + ) + + if self._supports.multiturn is False and len(req.messages) > 1: + raise GenkitError( + status='INVALID_ARGUMENT', + message=( + f"Model '{self._name}' does not support multiple messages, but {len(req.messages)} were provided." + ), + ) + + if self._supports.system_role is False: + for msg in req.messages: + if msg.role == 'system': + raise GenkitError( + status='INVALID_ARGUMENT', + message=f"Model '{self._name}' does not support system role, but system role was provided.", + ) + + if self._supports.tool_choice is False and req.tool_choice and req.tool_choice != 'auto': + raise GenkitError( + status='INVALID_ARGUMENT', + message=f"Model '{self._name}' does not support tool choice, but tool choice was provided.", + ) + + return await next_fn(params) diff --git a/py/packages/genkit/src/genkit/middleware/__init__.py b/py/packages/genkit/src/genkit/middleware/__init__.py new file mode 100644 index 0000000000..14d07d65c8 --- /dev/null +++ b/py/packages/genkit/src/genkit/middleware/__init__.py @@ -0,0 +1,69 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Middleware for Genkit model calls. + +This module provides middleware that can be used to modify model requests and +responses, add retry logic, implement fallback behavior, and more. + +Chain ordering: middleware is applied first-in, outermost. The first middleware +in the list wraps around the rest; calls flow in → out. 
+ +Example usage: + from genkit import Genkit + from genkit.middleware import retry, fallback + + ai = Genkit() + + response = await ai.generate( + model="gemini-pro", + prompt="Hello", + use=[ + retry(max_retries=3), + fallback(ai, models=["gemini-flash"]), + ], + ) +""" + +from genkit._core._middleware._augment_with_context import augment_with_context +from genkit._core._middleware._base import ( + BaseMiddleware, + GenerateHookParams, + ModelHookParams, + ToolHookParams, +) +from genkit._core._middleware._download_request_media import ( + download_request_media, +) +from genkit._core._middleware._fallback import fallback +from genkit._core._middleware._retry import retry +from genkit._core._middleware._simulate_system_prompt import ( + simulate_system_prompt, +) +from genkit._core._middleware._validate_support import validate_support + +__all__ = [ + 'BaseMiddleware', + 'GenerateHookParams', + 'ModelHookParams', + 'ToolHookParams', + 'augment_with_context', + 'download_request_media', + 'fallback', + 'retry', + 'simulate_system_prompt', + 'validate_support', +] diff --git a/py/packages/genkit/tests/genkit/ai/generate_test.py b/py/packages/genkit/tests/genkit/ai/generate_test.py index feffd1a990..6aacbd9090 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_test.py @@ -16,13 +16,18 @@ from genkit import ActionKind, Document, Genkit, Message, ModelResponse, ModelResponseChunk from genkit._ai._generate import generate_action +from genkit.middleware import ( + BaseMiddleware, + GenerateHookParams, + ModelHookParams, + ToolHookParams, +) from genkit._ai._model import text_from_content, text_from_message from genkit._ai._testing import ( ProgrammableModel, define_echo_model, define_programmable_model, ) -from genkit._core._action import ActionRunContext from genkit._core._model import ModelRequest from genkit._core._typing import ( DocumentPart, @@ -31,6 +36,8 @@ Part, Role, TextPart, + ToolRequest, + ToolRequestPart, ) @@ -144,35 +151,25 @@ async def test_simulates_doc_grounding( ) -@pytest.mark.asyncio -async def test_generate_applies_middleware( - setup_test: tuple[Genkit, ProgrammableModel], -) -> None: - """When middleware is provided, apply it.""" - ai, *_ = setup_test - define_echo_model(ai) - - async def pre_middle( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message(role=Role.USER, content=[Part(TextPart(text=f'PRE {txt}'))]), - ], - ), - ctx, +class PreMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message(role=Role.USER, content=[Part(TextPart(text=f'PRE {txt}'))]), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) ) - async def post_middle( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - resp: ModelResponse = await next(req, ctx) + +class PostMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + resp: ModelResponse = await next_fn(params) assert resp.message is not None txt = text_from_message(resp.message) return ModelResponse( @@ -180,6 +177,15 @@ async def 
post_middle( message=Message(role=Role.USER, content=[Part(TextPart(text=f'{txt} POST'))]), ) + +@pytest.mark.asyncio +async def test_generate_applies_middleware( + setup_test: tuple[Genkit, ProgrammableModel], +) -> None: + """When middleware is provided, apply it.""" + ai, *_ = setup_test + define_echo_model(ai) + response = await generate_action( ai.registry, GenerateActionOptions( @@ -191,7 +197,7 @@ async def post_middle( ), ], ), - middleware=[pre_middle, post_middle], + middleware=[PreMiddleware(), PostMiddleware()], ) assert response.text == '[ECHO] user: "PRE hi" POST' @@ -205,19 +211,6 @@ async def test_generate_middleware_next_fn_args_optional( ai, *_ = setup_test define_echo_model(ai) - async def post_middle( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - resp: ModelResponse = await next(req, ctx) - assert resp.message is not None - txt = text_from_message(resp.message) - return ModelResponse( - finish_reason=resp.finish_reason, - message=Message(role=Role.USER, content=[Part(TextPart(text=f'{txt} POST'))]), - ) - response = await generate_action( ai.registry, GenerateActionOptions( @@ -229,12 +222,42 @@ async def post_middle( ), ], ), - middleware=[post_middle], + middleware=[PostMiddleware()], ) assert response.text == '[ECHO] user: "hi" POST' +class AddContextMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + return await next_fn( + ModelHookParams( + request=params.request, + on_chunk=params.on_chunk, + context={**params.context, 'banana': True}, + ) + ) + + +class InjectContextMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[Part(TextPart(text=f'{txt} {params.context}'))], + ), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) + ) + + @pytest.mark.asyncio async def test_generate_middleware_can_modify_context( setup_test: tuple[Genkit, ProgrammableModel], @@ -243,31 +266,6 @@ async def test_generate_middleware_can_modify_context( ai, *_ = setup_test define_echo_model(ai) - async def add_context( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - return await next(req, ActionRunContext(context={**ctx.context, 'banana': True})) - - async def inject_context( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message( - role=Role.USER, - content=[Part(TextPart(text=f'{txt} {ctx.context}'))], - ), - ], - ), - ctx, - ) - response = await generate_action( ai.registry, GenerateActionOptions( @@ -279,7 +277,7 @@ async def inject_context( ), ], ), - middleware=[add_context, inject_context], + middleware=[AddContextMiddleware(), InjectContextMiddleware()], context={'foo': 'bar'}, ) @@ -307,44 +305,44 @@ async def test_generate_middleware_can_modify_stream( ] ] - async def modify_stream( - req: ModelRequest, - ctx: ActionRunContext, - on_chunk: Callable[[ModelResponseChunk], None] | None, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - # 4-param streaming middleware signature - if on_chunk: - on_chunk( - 
ModelResponseChunk( - role=Role.MODEL, - content=[Part(TextPart(text='something extra before'))], - ) - ) + got_chunks = [] - def chunk_handler(chunk: ModelResponseChunk) -> None: - if on_chunk: - on_chunk( + def collect_chunks(c: ModelResponseChunk) -> None: + got_chunks.append(text_from_content(c.content)) + + class ModifyStreamMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + if params.on_chunk: + params.on_chunk( ModelResponseChunk( role=Role.MODEL, - content=[Part(TextPart(text=f'intercepted: {text_from_content(chunk.content)}'))], + content=[Part(TextPart(text='something extra before'))], ) ) - resp = await next(req, ctx, chunk_handler) - if on_chunk: - on_chunk( - ModelResponseChunk( - role=Role.MODEL, - content=[Part(TextPart(text='something extra after'))], - ) - ) - return resp - - got_chunks = [] + def chunk_handler(chunk: ModelResponseChunk) -> None: + if params.on_chunk: + params.on_chunk( + ModelResponseChunk( + role=Role.MODEL, + content=[Part(TextPart(text=f'intercepted: {text_from_content(chunk.content)}'))], + ) + ) - def collect_chunks(c: ModelResponseChunk) -> None: - got_chunks.append(text_from_content(c.content)) + new_params = ModelHookParams( + request=params.request, + on_chunk=chunk_handler, + context=params.context, + ) + resp = await next_fn(new_params) + if params.on_chunk: + params.on_chunk( + ModelResponseChunk( + role=Role.MODEL, + content=[Part(TextPart(text='something extra after'))], + ) + ) + return resp response = await generate_action( ai.registry, @@ -357,7 +355,7 @@ def collect_chunks(c: ModelResponseChunk) -> None: ), ], ), - middleware=[modify_stream], + middleware=[ModifyStreamMiddleware()], on_chunk=collect_chunks, ) @@ -371,6 +369,133 @@ def collect_chunks(c: ModelResponseChunk) -> None: ] +class TrackGenerateMiddleware(BaseMiddleware): + """Middleware that records wrap_generate calls per turn.""" + + def __init__(self) -> None: + self.iterations: list[int] = [] + + async def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + self.iterations.append(params.iteration) + return await next_fn(params) + + +@pytest.mark.asyncio +async def test_wrap_generate_called_per_turn( + setup_test: tuple[Genkit, ProgrammableModel], +) -> None: + """wrap_generate is invoked for each turn of the generate loop.""" + ai, pm = setup_test + define_echo_model(ai) + + track_mw = TrackGenerateMiddleware() + + # No tools: single turn, wrap_generate called once with iteration=0 + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='done'))]), + ) + ) + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + ), + middleware=[track_mw], + ) + assert response.text == 'done' + assert track_mw.iterations == [0] + + # With tools: two turns (model->tool->model), wrap_generate called for each + track_mw2 = TrackGenerateMiddleware() + pm.responses.append( + ModelResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=ToolRequestPart(tool_request=ToolRequest(name='testTool', input={}, ref='r1')))], + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='final'))]), + ) + ) + response2 = await generate_action( 
+ ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + tools=['testTool'], + ), + middleware=[track_mw2], + ) + assert response2.text == 'final' + assert track_mw2.iterations == [0, 1] + + +class TrackToolMiddleware(BaseMiddleware): + """Middleware that records wrap_tool calls.""" + + def __init__(self) -> None: + self.tool_names: list[str] = [] + + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[tuple['Part | None', 'Part | None']]], + ) -> tuple['Part | None', 'Part | None']: + self.tool_names.append(params.tool_request_part.tool_request.name) + return await next_fn(params) + + +@pytest.mark.asyncio +async def test_wrap_tool_called_on_tool_execution( + setup_test: tuple[Genkit, ProgrammableModel], +) -> None: + """wrap_tool is invoked for each tool execution.""" + ai, pm = setup_test + + @ai.tool(name='myTool') + async def my_tool() -> object: + return 'result' + + pm.responses.append( + ModelResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=ToolRequestPart(tool_request=ToolRequest(name='myTool', input={}, ref='r1')))], + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='done'))]), + ) + ) + + track_mw = TrackToolMiddleware() + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + tools=['myTool'], + ), + middleware=[track_mw], + ) + assert response.text == 'done' + assert track_mw.tool_names == ['myTool'] + + ########################################################################## # run tests from /tests/specs/generate.yaml ########################################################################## diff --git a/py/packages/genkit/tests/genkit/ai/middleware_test.py b/py/packages/genkit/tests/genkit/ai/middleware_test.py index fae6fe1c09..5043c2580a 100644 --- a/py/packages/genkit/tests/genkit/ai/middleware_test.py +++ b/py/packages/genkit/tests/genkit/ai/middleware_test.py @@ -22,8 +22,8 @@ import pytest from genkit import Document, Message, ModelResponse -from genkit._ai._middleware import augment_with_context -from genkit._core._action import ActionRunContext +from genkit.middleware import augment_with_context +from genkit._core._middleware._base import ModelHookParams from genkit._core._model import ModelRequest from genkit._core._typing import ( DocumentPart, @@ -38,11 +38,14 @@ async def run_augmenter(req: ModelRequest) -> ModelRequest: augmenter = augment_with_context() req_future = asyncio.Future() - async def next(req: ModelRequest, _: ActionRunContext) -> ModelResponse: - req_future.set_result(req) + async def next_fn(params: ModelHookParams) -> ModelResponse: + req_future.set_result(params.request) return ModelResponse(message=Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))])) - await augmenter(req, ActionRunContext(), next) + await augmenter.wrap_model( + ModelHookParams(request=req, on_chunk=None, context={}), + next_fn, + ) return req_future.result() diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index 0310c779f2..25f8d8cb37 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -31,6 +31,7 @@ ) from genkit._core._action 
import ActionKind, ActionRunContext from genkit._core._model import ModelRequest +from genkit.middleware import BaseMiddleware, ModelHookParams from genkit._core._typing import ( BaseDataPoint, Details, @@ -998,30 +999,25 @@ class TestSchema(BaseModel): assert (await stream_result.response).request == want -@pytest.mark.asyncio -async def test_generate_with_middleware( - setup_test: SetupFixture, -) -> None: - """When middleware is provided, applies it.""" - ai, *_ = setup_test - - async def pre_middle( - req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]] - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message(role=Role.USER, content=[Part(root=TextPart(text=f'PRE {txt}'))]), - ], - ), - ctx, +class PreMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text=f'PRE {txt}'))]), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) ) - async def post_middle( - req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]] - ) -> ModelResponse: - resp: ModelResponse = await next(req, ctx) + +class PostMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + resp: ModelResponse = await next_fn(params) assert resp.message is not None txt = text_from_message(resp.message) return ModelResponse( @@ -1029,17 +1025,44 @@ async def post_middle( message=Message(role=Role.USER, content=[Part(root=TextPart(text=f'{txt} POST'))]), ) + +@pytest.mark.asyncio +async def test_generate_with_middleware( + setup_test: SetupFixture, +) -> None: + """When middleware is provided, applies it.""" + ai, *_ = setup_test + want = '[ECHO] user: "PRE hi" POST' - response = await ai.generate(model='echoModel', prompt='hi', use=[pre_middle, post_middle]) + response = await ai.generate(model='echoModel', prompt='hi', use=[PreMiddleware(), PostMiddleware()]) assert response.text == want - stream_result = ai.generate_stream(model='echoModel', prompt='hi', use=[pre_middle, post_middle]) + stream_result = ai.generate_stream(model='echoModel', prompt='hi', use=[PreMiddleware(), PostMiddleware()]) assert (await stream_result.response).text == want +class InjectContextMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[Part(root=TextPart(text=f'{txt} {params.context}'))], + ), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) + ) + + @pytest.mark.asyncio async def test_generate_passes_through_current_action_context( setup_test: SetupFixture, @@ -1047,24 +1070,8 @@ async def test_generate_passes_through_current_action_context( """Test that generate uses current action context by default.""" ai, *_ = setup_test - async def inject_context( - req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]] - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message( - 
role=Role.USER,
-                        content=[Part(root=TextPart(text=f'{txt} {ctx.context}'))],
-                    ),
-                ],
-            ),
-            ctx,
-        )
-
     async def action_fn() -> ModelResponse:
-        return await ai.generate(model='echoModel', prompt='hi', use=[inject_context])
+        return await ai.generate(model='echoModel', prompt='hi', use=[InjectContextMiddleware()])
 
     action = ai.registry.register_action(name='test_action', kind=ActionKind.CUSTOM, fn=action_fn)
     action_response = await action.run(context={'foo': 'bar'})
@@ -1079,27 +1086,11 @@ async def test_generate_uses_explicitly_passed_in_context(
     """Generate uses specific context instead of current action context."""
     ai, *_ = setup_test
 
-    async def inject_context(
-        req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]]
-    ) -> ModelResponse:
-        txt = ''.join(text_from_message(m) for m in req.messages)
-        return await next(
-            ModelRequest(
-                messages=[
-                    Message(
-                        role=Role.USER,
-                        content=[Part(root=TextPart(text=f'{txt} {ctx.context}'))],
-                    ),
-                ],
-            ),
-            ctx,
-        )
-
     async def action_fn() -> ModelResponse:
         return await ai.generate(
             model='echoModel',
             prompt='hi',
-            use=[inject_context],
+            use=[InjectContextMiddleware()],
             context={'bar': 'baz'},
         )
diff --git a/py/samples/middleware/src/main.py b/py/samples/middleware/src/main.py
index fe4496bbe3..f437bf968b 100644
--- a/py/samples/middleware/src/main.py
+++ b/py/samples/middleware/src/main.py
@@ -21,8 +21,8 @@
 import structlog
 from pydantic import BaseModel, Field
 
-from genkit import Genkit, Message, ModelRequest, ModelResponse, Part, Role, TextPart
-from genkit._core._action import ActionRunContext
+from genkit import Genkit, Message, ModelResponse, Part, Role, TextPart
+from genkit.middleware import BaseMiddleware, ModelHookParams
 from genkit.plugins.google_genai import GoogleAI
 
 logger = structlog.get_logger(__name__)
@@ -32,41 +32,56 @@
 class PromptInput(BaseModel):
     """Input shared by middleware flows."""
 
-    prompt: str = Field(default='Explain recursion simply.', description='Prompt to send to the model')
+    prompt: str = Field(
+        default='Explain recursion simply.',
+        description='Prompt to send to the model',
+    )
 
 
-async def logging_middleware(
-    req: ModelRequest,
-    ctx: ActionRunContext,
-    next_handler: Callable[[ModelRequest, ActionRunContext], Awaitable[ModelResponse]],
-) -> ModelResponse:
+class LoggingMiddleware(BaseMiddleware):
     """Log request/response details without changing behavior."""
 
-    await logger.ainfo('middleware saw request', message_count=len(req.messages))
-    response = await next_handler(req, ctx)
-    await logger.ainfo('middleware saw response', finish_reason=response.finish_reason)
-    return response
-
-
-async def concise_reply_middleware(
-    req: ModelRequest,
-    ctx: ActionRunContext,
-    next_handler: Callable[[ModelRequest, ActionRunContext], Awaitable[ModelResponse]],
-) -> ModelResponse:
+    async def wrap_model(
+        self,
+        params: ModelHookParams,
+        next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
+        await logger.ainfo('middleware saw request', message_count=len(params.request.messages))
+        response = await next_fn(params)
+        await logger.ainfo(
+            'middleware saw response',
+            finish_reason=response.finish_reason,
+        )
+        return response
+
+
+class ConciseReplyMiddleware(BaseMiddleware):
     """Add a short system instruction before the model call."""
 
-    system_message = Message(
-        role=Role.SYSTEM,
-        content=[Part(root=TextPart(text='Answer in one short paragraph.'))],
-    )
-    return await next_handler(req.model_copy(update={'messages': [system_message, *req.messages]}), ctx)
+    async def wrap_model(
+        self,
+        params: ModelHookParams,
+        next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
+        system_message = Message(
+            role=Role.SYSTEM,
+            content=[Part(root=TextPart(text='Answer in one short paragraph.'))],
+        )
+        new_req = params.request.model_copy(update={'messages': [system_message, *params.request.messages]})
+        return await next_fn(
+            ModelHookParams(
+                request=new_req,
+                on_chunk=params.on_chunk,
+                context=params.context,
+            )
+        )
 
 
 @ai.flow()
 async def logging_demo(input: PromptInput) -> str:
     """Run a prompt through a read-only middleware."""
 
-    response = await ai.generate(prompt=input.prompt, use=[logging_middleware])
+    response = await ai.generate(prompt=input.prompt, use=[LoggingMiddleware()])
     return response.text
 
 
@@ -74,7 +89,7 @@ async def logging_demo(input: PromptInput) -> str:
 async def request_modifier_demo(input: PromptInput) -> str:
     """Run a prompt through a request-modifying middleware."""
 
-    response = await ai.generate(prompt=input.prompt, use=[concise_reply_middleware])
+    response = await ai.generate(prompt=input.prompt, use=[ConciseReplyMiddleware()])
     return response.text
 
 
@@ -84,7 +99,9 @@ async def main() -> None:
     print(await logging_demo(PromptInput()))  # noqa: T201
     print(await request_modifier_demo(PromptInput(prompt='Write a haiku about recursion.')))  # noqa: T201
     except Exception as error:
-        print(f'Set GEMINI_API_KEY to a valid value before running this sample directly.\n{error}')  # noqa: T201
+        print(  # noqa: T201
+            f'Set GEMINI_API_KEY to a valid value before running this sample directly.\n{error}'
+        )
 
 
 if __name__ == '__main__':