Skip to content

Commit 777fbe7

Browse files
committed
Use get_type_hints() in function_schema()
1 parent f6d1152 commit 777fbe7

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from collections.abc import Awaitable, Callable
99
from dataclasses import dataclass, field
10+
from functools import partial
1011
from inspect import Parameter, signature
1112
from typing import TYPE_CHECKING, Any, Concatenate, cast, get_origin
1213

@@ -17,7 +18,7 @@
1718
from pydantic.json_schema import GenerateJsonSchema
1819
from pydantic.plugin._schema_validator import create_schema_validator
1920
from pydantic_core import SchemaValidator, core_schema
20-
from typing_extensions import ParamSpec, TypeIs, TypeVar
21+
from typing_extensions import ParamSpec, TypeIs, TypeVar, get_type_hints
2122

2223
from ._griffe import doc_descriptions
2324
from ._run_context import RunContext
@@ -90,9 +91,6 @@ def function_schema( # noqa: C901
9091
Returns:
9192
A `FunctionSchema` instance.
9293
"""
93-
if takes_ctx is None:
94-
takes_ctx = _takes_ctx(function)
95-
9694
config = ConfigDict(title=function.__name__, use_attribute_docstrings=True)
9795
config_wrapper = ConfigWrapper(config)
9896
gen_schema = _generate_schema.GenerateSchema(config_wrapper)
@@ -103,16 +101,17 @@ def function_schema( # noqa: C901
103101
except ValueError as e:
104102
errors.append(str(e))
105103
sig = signature(lambda: None)
104+
original_func = function.func if isinstance(function, partial) else function
106105

107-
type_hints = _typing_extra.get_function_type_hints(function)
106+
type_hints = get_type_hints(original_func)
108107

109108
var_kwargs_schema: core_schema.CoreSchema | None = None
110109
fields: dict[str, core_schema.TypedDictField] = {}
111110
positional_fields: list[str] = []
112111
var_positional_field: str | None = None
113112
decorators = _decorators.DecoratorInfos()
114113

115-
description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)
114+
description, field_descriptions = doc_descriptions(original_func, sig, docstring_format=docstring_format)
116115

117116
if require_parameter_descriptions:
118117
if takes_ctx:
@@ -127,6 +126,9 @@ def function_schema( # noqa: C901
127126
errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')
128127

129128
for index, (name, p) in enumerate(sig.parameters.items()):
129+
if index == 0 and takes_ctx is None:
130+
takes_ctx = p.annotation is not sig.empty and _is_call_ctx(type_hints[name])
131+
130132
if p.annotation is sig.empty:
131133
if takes_ctx and index == 0:
132134
# should be the `context` argument, skip
@@ -234,7 +236,7 @@ def function_schema( # noqa: C901
234236
TargetCallable = WithCtx[P, R] | WithoutCtx[P, R]
235237

236238

237-
def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]:
239+
def _takes_ctx(callable_obj: TargetCallable[P, R]) -> TypeIs[WithCtx[P, R]]: # pyright: ignore[reportUnusedFunction]
238240
"""Check if a callable takes a `RunContext` first argument.
239241
240242
Args:

0 commit comments

Comments
 (0)