diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 2b8270f322..9872b7383b 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass, field +from functools import partial from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin @@ -17,7 +18,7 @@ from pydantic.json_schema import GenerateJsonSchema from pydantic.plugin._schema_validator import create_schema_validator from pydantic_core import SchemaValidator, core_schema -from typing_extensions import ParamSpec, TypeIs, TypeVar +from typing_extensions import ParamSpec, TypeIs, TypeVar, get_type_hints from ._griffe import doc_descriptions from ._run_context import RunContext @@ -90,9 +91,6 @@ def function_schema( # noqa: C901 Returns: A `FunctionSchema` instance. """ - if takes_ctx is None: - takes_ctx = _takes_ctx(function) - config = ConfigDict(title=function.__name__, use_attribute_docstrings=True) config_wrapper = ConfigWrapper(config) gen_schema = _generate_schema.GenerateSchema(config_wrapper) @@ -103,8 +101,10 @@ def function_schema( # noqa: C901 except ValueError as e: errors.append(str(e)) sig = signature(lambda: None) + original_func = function.func if isinstance(function, partial) else function + function = cast(Callable[..., Any], function) # cope with pyright changing the type from the isinstance() check. - type_hints = _typing_extra.get_function_type_hints(function) + type_hints = get_type_hints(original_func) var_kwargs_schema: core_schema.CoreSchema | None = None fields: dict[str, core_schema.TypedDictField] = {} @@ -112,21 +112,13 @@ def function_schema( # noqa: C901 var_positional_field: str | None = None decorators = _decorators.DecoratorInfos() - description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format) - - if require_parameter_descriptions: - if takes_ctx: - parameters_without_ctx = set( - name for name in sig.parameters if not _is_call_ctx(sig.parameters[name].annotation) - ) - missing_params = parameters_without_ctx - set(field_descriptions) - else: - missing_params = set(sig.parameters) - set(field_descriptions) - - if missing_params: - errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}') + description, field_descriptions = doc_descriptions(original_func, sig, docstring_format=docstring_format) + missing_param_descriptions: set[str] = set() for index, (name, p) in enumerate(sig.parameters.items()): + if index == 0 and takes_ctx is None: + takes_ctx = p.annotation is not sig.empty and _is_call_ctx(type_hints[name]) + if p.annotation is sig.empty: if takes_ctx and index == 0: # should be the `context` argument, skip @@ -148,6 +140,10 @@ def function_schema( # noqa: C901 continue field_name = p.name + + if require_parameter_descriptions and field_name not in field_descriptions: + missing_param_descriptions.add(field_name) + if p.kind == Parameter.VAR_KEYWORD: var_kwargs_schema = gen_schema.generate_schema(annotation) else: @@ -178,6 +174,9 @@ def function_schema( # noqa: C901 elif p.kind == Parameter.VAR_POSITIONAL: var_positional_field = field_name + if missing_param_descriptions: + errors.append(f'Missing parameter descriptions for {", ".join(missing_param_descriptions)}') + if errors: from .exceptions import UserError @@ -219,7 +218,7 @@ def function_schema( # noqa: C901 single_arg_name=single_arg_name, positional_fields=positional_fields, var_positional_field=var_positional_field, - takes_ctx=takes_ctx, + takes_ctx=bool(takes_ctx), is_async=is_async_callable(function), function=function, ) @@ -234,7 +233,7 @@ def function_schema( # noqa: C901 TargetCallable = WithCtx[P, R] | WithoutCtx[P, R] -def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]: +def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]: # pyright: ignore[reportUnusedFunction] """Check if a callable takes a `RunContext` first argument. Args: