|
14 | 14 | from fastmcp.prompts import Message |
15 | 15 | from fastmcp.server import FastMCP |
16 | 16 | from fastmcp.server.context import Context as FastMCPContext |
17 | | -from pydantic import Field |
| 17 | +from pydantic import BaseModel, Field, create_model |
18 | 18 | from qdrant_client import models |
19 | 19 | from qdrant_client.async_qdrant_client import AsyncQdrantClient |
20 | 20 | from starlette.requests import Request |
@@ -99,6 +99,36 @@ def reranker(self) -> CrossEncoder | None: |
99 | 99 | server = PlexServer(settings=settings) |
100 | 100 |
|
101 | 101 |
|
| 102 | +def _request_model(name: str, fn: Callable[..., Any]) -> type[BaseModel] | None: |
| 103 | + """Generate a Pydantic model representing the callable's parameters.""" |
| 104 | + |
| 105 | + signature = inspect.signature(fn) |
| 106 | + if not signature.parameters: |
| 107 | + return None |
| 108 | + |
| 109 | + fields: dict[str, tuple[Any, Any]] = {} |
| 110 | + for param_name, parameter in signature.parameters.items(): |
| 111 | + annotation = ( |
| 112 | + parameter.annotation |
| 113 | + if parameter.annotation is not inspect._empty |
| 114 | + else Any |
| 115 | + ) |
| 116 | + default = ( |
| 117 | + parameter.default |
| 118 | + if parameter.default is not inspect._empty |
| 119 | + else ... |
| 120 | + ) |
| 121 | + fields[param_name] = (annotation, default) |
| 122 | + |
| 123 | + if not fields: |
| 124 | + return None |
| 125 | + |
| 126 | + model_name = "".join(part.capitalize() for part in name.replace("-", "_").split("_")) |
| 127 | + model_name = f"{model_name or 'Request'}Request" |
| 128 | + request_model = create_model(model_name, **fields) # type: ignore[arg-type] |
| 129 | + return request_model |
| 130 | + |
| 131 | + |
102 | 132 | async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: |
103 | 133 | """Locate records matching an identifier or title.""" |
104 | 134 | # First, try direct ID lookup |
@@ -522,15 +552,50 @@ async def rest_docs(request: Request) -> Response: |
522 | 552 | def _build_openapi_schema() -> dict[str, Any]: |
523 | 553 | app = FastAPI() |
524 | 554 | for name, tool in server._tool_manager._tools.items(): |
525 | | - app.post(f"/rest/{name}")(tool.fn) |
| 555 | + request_model = _request_model(name, tool.fn) |
| 556 | + |
| 557 | + if request_model is None: |
| 558 | + app.post(f"/rest/{name}")(tool.fn) |
| 559 | + continue |
| 560 | + |
| 561 | + async def _tool_stub(payload: request_model) -> None: # type: ignore[name-defined] |
| 562 | + pass |
| 563 | + |
| 564 | + _tool_stub.__name__ = f"tool_{name.replace('-', '_')}" |
| 565 | + _tool_stub.__doc__ = tool.fn.__doc__ |
| 566 | + _tool_stub.__signature__ = inspect.Signature( |
| 567 | + parameters=[ |
| 568 | + inspect.Parameter( |
| 569 | + "payload", |
| 570 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 571 | + annotation=request_model, |
| 572 | + ) |
| 573 | + ], |
| 574 | + return_annotation=Any, |
| 575 | + ) |
| 576 | + |
| 577 | + app.post(f"/rest/{name}")(_tool_stub) |
526 | 578 | for name, prompt in server._prompt_manager._prompts.items(): |
527 | 579 | async def _p_stub(**kwargs): # noqa: ARG001 |
528 | 580 | pass |
529 | 581 | _p_stub.__name__ = f"prompt_{name.replace('-', '_')}" |
530 | 582 | _p_stub.__doc__ = prompt.fn.__doc__ |
531 | | - _p_stub.__signature__ = inspect.signature(prompt.fn).replace( |
532 | | - return_annotation=Any |
533 | | - ) |
| 583 | + request_model = _request_model(name, prompt.fn) |
| 584 | + if request_model is None: |
| 585 | + _p_stub.__signature__ = inspect.signature(prompt.fn).replace( |
| 586 | + return_annotation=Any |
| 587 | + ) |
| 588 | + else: |
| 589 | + _p_stub.__signature__ = inspect.Signature( |
| 590 | + parameters=[ |
| 591 | + inspect.Parameter( |
| 592 | + "payload", |
| 593 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 594 | + annotation=request_model, |
| 595 | + ) |
| 596 | + ], |
| 597 | + return_annotation=Any, |
| 598 | + ) |
534 | 599 | app.post(f"/rest/prompt/{name}")(_p_stub) |
535 | 600 | for uri, resource in server._resource_manager._templates.items(): |
536 | 601 | path = uri.replace("resource://", "") |
|
0 commit comments