77
88from collections .abc import Awaitable , Callable
99from dataclasses import dataclass , field
10+ from functools import partial
1011from inspect import Parameter , signature
1112from typing import TYPE_CHECKING , Any , Concatenate , cast , get_origin
1213
1718from pydantic .json_schema import GenerateJsonSchema
1819from pydantic .plugin ._schema_validator import create_schema_validator
1920from pydantic_core import SchemaValidator , core_schema
20- from typing_extensions import ParamSpec , TypeIs , TypeVar
21+ from typing_extensions import ParamSpec , TypeIs , TypeVar , get_type_hints
2122
2223from ._griffe import doc_descriptions
2324from ._run_context import RunContext
@@ -90,9 +91,6 @@ def function_schema( # noqa: C901
9091 Returns:
9192 A `FunctionSchema` instance.
9293 """
93- if takes_ctx is None :
94- takes_ctx = _takes_ctx (function )
95-
9694 config = ConfigDict (title = function .__name__ , use_attribute_docstrings = True )
9795 config_wrapper = ConfigWrapper (config )
9896 gen_schema = _generate_schema .GenerateSchema (config_wrapper )
@@ -103,16 +101,17 @@ def function_schema( # noqa: C901
103101 except ValueError as e :
104102 errors .append (str (e ))
105103 sig = signature (lambda : None )
104+ original_func = function .func if isinstance (function , partial ) else function
106105
107- type_hints = _typing_extra . get_function_type_hints ( function )
106+ type_hints = get_type_hints ( original_func )
108107
109108 var_kwargs_schema : core_schema .CoreSchema | None = None
110109 fields : dict [str , core_schema .TypedDictField ] = {}
111110 positional_fields : list [str ] = []
112111 var_positional_field : str | None = None
113112 decorators = _decorators .DecoratorInfos ()
114113
115- description , field_descriptions = doc_descriptions (function , sig , docstring_format = docstring_format )
114+ description , field_descriptions = doc_descriptions (original_func , sig , docstring_format = docstring_format )
116115
117116 if require_parameter_descriptions :
118117 if takes_ctx :
@@ -127,6 +126,9 @@ def function_schema( # noqa: C901
127126 errors .append (f'Missing parameter descriptions for { ", " .join (missing_params )} ' )
128127
129128 for index , (name , p ) in enumerate (sig .parameters .items ()):
129+ if index == 0 and takes_ctx is None :
130+ takes_ctx = p .annotation is not sig .empty and _is_call_ctx (type_hints [name ])
131+
130132 if p .annotation is sig .empty :
131133 if takes_ctx and index == 0 :
132134 # should be the `context` argument, skip
@@ -234,7 +236,7 @@ def function_schema( # noqa: C901
234236TargetCallable = WithCtx [P , R ] | WithoutCtx [P , R ]
235237
236238
237- def _takes_ctx (callable_obj : TargetCallable [P , R ]) -> TypeIs [WithCtx [P , R ]]:
239+ def _takes_ctx (callable_obj : TargetCallable [P , R ]) -> TypeIs [WithCtx [P , R ]]: # pyright: ignore[reportUnusedFunction]
238240 """Check if a callable takes a `RunContext` first argument.
239241
240242 Args:
0 commit comments