Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions pydantic_ai_slim/pydantic_ai/_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from functools import partial
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin

Expand All @@ -17,7 +18,7 @@
from pydantic.json_schema import GenerateJsonSchema
from pydantic.plugin._schema_validator import create_schema_validator
from pydantic_core import SchemaValidator, core_schema
from typing_extensions import ParamSpec, TypeIs, TypeVar
from typing_extensions import ParamSpec, TypeIs, TypeVar, get_type_hints

from ._griffe import doc_descriptions
from ._run_context import RunContext
Expand Down Expand Up @@ -90,9 +91,6 @@ def function_schema( # noqa: C901
Returns:
A `FunctionSchema` instance.
"""
if takes_ctx is None:
takes_ctx = _takes_ctx(function)

config = ConfigDict(title=function.__name__, use_attribute_docstrings=True)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper)
Expand All @@ -103,30 +101,24 @@ def function_schema( # noqa: C901
except ValueError as e:
errors.append(str(e))
sig = signature(lambda: None)
original_func = function.func if isinstance(function, partial) else function
function = cast(Callable[..., Any], function) # cope with pyright changing the type from the isinstance() check.

type_hints = _typing_extra.get_function_type_hints(function)
type_hints = get_type_hints(original_func)

var_kwargs_schema: core_schema.CoreSchema | None = None
fields: dict[str, core_schema.TypedDictField] = {}
positional_fields: list[str] = []
var_positional_field: str | None = None
decorators = _decorators.DecoratorInfos()

description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)

if require_parameter_descriptions:
if takes_ctx:
parameters_without_ctx = set(
name for name in sig.parameters if not _is_call_ctx(sig.parameters[name].annotation)
)
missing_params = parameters_without_ctx - set(field_descriptions)
else:
missing_params = set(sig.parameters) - set(field_descriptions)

if missing_params:
errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')
description, field_descriptions = doc_descriptions(original_func, sig, docstring_format=docstring_format)
missing_param_descriptions: set[str] = set()

for index, (name, p) in enumerate(sig.parameters.items()):
if index == 0 and takes_ctx is None:
takes_ctx = p.annotation is not sig.empty and _is_call_ctx(type_hints[name])
Comment on lines +119 to +120
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using _takes_ctx(), do the logic here. This avoids computing the type hints twice.


if p.annotation is sig.empty:
if takes_ctx and index == 0:
# should be the `context` argument, skip
Expand All @@ -148,6 +140,10 @@ def function_schema( # noqa: C901
continue

field_name = p.name

if require_parameter_descriptions and field_name not in field_descriptions:
missing_param_descriptions.add(field_name)

if p.kind == Parameter.VAR_KEYWORD:
var_kwargs_schema = gen_schema.generate_schema(annotation)
else:
Expand Down Expand Up @@ -178,6 +174,9 @@ def function_schema( # noqa: C901
elif p.kind == Parameter.VAR_POSITIONAL:
var_positional_field = field_name

if missing_param_descriptions:
errors.append(f'Missing parameter descriptions for {", ".join(missing_param_descriptions)}')

if errors:
from .exceptions import UserError

Expand Down Expand Up @@ -219,7 +218,7 @@ def function_schema( # noqa: C901
single_arg_name=single_arg_name,
positional_fields=positional_fields,
var_positional_field=var_positional_field,
takes_ctx=takes_ctx,
takes_ctx=bool(takes_ctx),
is_async=is_async_callable(function),
function=function,
)
Expand All @@ -234,7 +233,7 @@ def function_schema( # noqa: C901
TargetCallable = WithCtx[P, R] | WithoutCtx[P, R]


def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]:
def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]: # pyright: ignore[reportUnusedFunction]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still used in one other place in Pydantic AI, will see how this can be cleaned up in a future PR.

"""Check if a callable takes a `RunContext` first argument.

Args:
Expand Down
Loading