Skip to content

Commit 13fd648

Browse files
committed
Use get_type_hints() in function_schema()
1 parent f6d1152 commit 13fd648

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from collections.abc import Awaitable, Callable
99
from dataclasses import dataclass, field
10+
from functools import partial
1011
from inspect import Parameter, signature
1112
from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin
1213

@@ -17,7 +18,7 @@
1718
from pydantic.json_schema import GenerateJsonSchema
1819
from pydantic.plugin._schema_validator import create_schema_validator
1920
from pydantic_core import SchemaValidator, core_schema
20-
from typing_extensions import ParamSpec, TypeIs, TypeVar
21+
from typing_extensions import ParamSpec, TypeIs, TypeVar, get_type_hints
2122

2223
from ._griffe import doc_descriptions
2324
from ._run_context import RunContext
@@ -90,9 +91,6 @@ def function_schema( # noqa: C901
9091
Returns:
9192
A `FunctionSchema` instance.
9293
"""
93-
if takes_ctx is None:
94-
takes_ctx = _takes_ctx(function)
95-
9694
config = ConfigDict(title=function.__name__, use_attribute_docstrings=True)
9795
config_wrapper = ConfigWrapper(config)
9896
gen_schema = _generate_schema.GenerateSchema(config_wrapper)
@@ -103,30 +101,24 @@ def function_schema( # noqa: C901
103101
except ValueError as e:
104102
errors.append(str(e))
105103
sig = signature(lambda: None)
104+
original_func = function.func if isinstance(function, partial) else function
105+
function = cast(Callable[..., Any], function) # cope with pyright changing the type from the isinstance() check.
106106

107-
type_hints = _typing_extra.get_function_type_hints(function)
107+
type_hints = get_type_hints(original_func)
108108

109109
var_kwargs_schema: core_schema.CoreSchema | None = None
110110
fields: dict[str, core_schema.TypedDictField] = {}
111111
positional_fields: list[str] = []
112112
var_positional_field: str | None = None
113113
decorators = _decorators.DecoratorInfos()
114114

115-
description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)
116-
117-
if require_parameter_descriptions:
118-
if takes_ctx:
119-
parameters_without_ctx = set(
120-
name for name in sig.parameters if not _is_call_ctx(sig.parameters[name].annotation)
121-
)
122-
missing_params = parameters_without_ctx - set(field_descriptions)
123-
else:
124-
missing_params = set(sig.parameters) - set(field_descriptions)
125-
126-
if missing_params:
127-
errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')
115+
description, field_descriptions = doc_descriptions(original_func, sig, docstring_format=docstring_format)
116+
missing_param_descriptions: set[str] = set()
128117

129118
for index, (name, p) in enumerate(sig.parameters.items()):
119+
if index == 0 and takes_ctx is None:
120+
takes_ctx = p.annotation is not sig.empty and _is_call_ctx(type_hints[name])
121+
130122
if p.annotation is sig.empty:
131123
if takes_ctx and index == 0:
132124
# should be the `context` argument, skip
@@ -148,6 +140,10 @@ def function_schema( # noqa: C901
148140
continue
149141

150142
field_name = p.name
143+
144+
if require_parameter_descriptions and field_name not in field_descriptions:
145+
missing_param_descriptions.add(field_name)
146+
151147
if p.kind == Parameter.VAR_KEYWORD:
152148
var_kwargs_schema = gen_schema.generate_schema(annotation)
153149
else:
@@ -178,6 +174,9 @@ def function_schema( # noqa: C901
178174
elif p.kind == Parameter.VAR_POSITIONAL:
179175
var_positional_field = field_name
180176

177+
if missing_param_descriptions:
178+
errors.append(f'Missing parameter descriptions for {", ".join(missing_param_descriptions)}')
179+
181180
if errors:
182181
from .exceptions import UserError
183182

@@ -219,7 +218,7 @@ def function_schema( # noqa: C901
219218
single_arg_name=single_arg_name,
220219
positional_fields=positional_fields,
221220
var_positional_field=var_positional_field,
222-
takes_ctx=takes_ctx,
221+
takes_ctx=bool(takes_ctx),
223222
is_async=is_async_callable(function),
224223
function=function,
225224
)
@@ -234,7 +233,7 @@ def function_schema( # noqa: C901
234233
TargetCallable = WithCtx[P, R] | WithoutCtx[P, R]
235234

236235

237-
def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]:
236+
def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]: # pyright: ignore[reportUnusedFunction]
238237
"""Check if a callable takes a `RunContext` first argument.
239238
240239
Args:

0 commit comments

Comments
 (0)