77
88from collections .abc import Awaitable , Callable
99from dataclasses import dataclass , field
10+ from functools import partial
1011from inspect import Parameter , signature
1112from typing import TYPE_CHECKING , Any , Concatenate , cast , get_origin
1213
1718from pydantic .json_schema import GenerateJsonSchema
1819from pydantic .plugin ._schema_validator import create_schema_validator
1920from pydantic_core import SchemaValidator , core_schema
20- from typing_extensions import ParamSpec , TypeIs , TypeVar
21+ from typing_extensions import ParamSpec , TypeIs , TypeVar , get_type_hints
2122
2223from ._griffe import doc_descriptions
2324from ._run_context import RunContext
@@ -90,9 +91,6 @@ def function_schema( # noqa: C901
9091 Returns:
9192 A `FunctionSchema` instance.
9293 """
93- if takes_ctx is None :
94- takes_ctx = _takes_ctx (function )
95-
9694 config = ConfigDict (title = function .__name__ , use_attribute_docstrings = True )
9795 config_wrapper = ConfigWrapper (config )
9896 gen_schema = _generate_schema .GenerateSchema (config_wrapper )
@@ -103,30 +101,24 @@ def function_schema( # noqa: C901
103101 except ValueError as e :
104102 errors .append (str (e ))
105103 sig = signature (lambda : None )
104+ original_func = function .func if isinstance (function , partial ) else function
105+ function = cast (Callable [..., Any ], function ) # cope with pyright changing the type from the isinstance() check.
106106
107- type_hints = _typing_extra . get_function_type_hints ( function )
107+ type_hints = get_type_hints ( original_func )
108108
109109 var_kwargs_schema : core_schema .CoreSchema | None = None
110110 fields : dict [str , core_schema .TypedDictField ] = {}
111111 positional_fields : list [str ] = []
112112 var_positional_field : str | None = None
113113 decorators = _decorators .DecoratorInfos ()
114114
115- description , field_descriptions = doc_descriptions (function , sig , docstring_format = docstring_format )
116-
117- if require_parameter_descriptions :
118- if takes_ctx :
119- parameters_without_ctx = set (
120- name for name in sig .parameters if not _is_call_ctx (sig .parameters [name ].annotation )
121- )
122- missing_params = parameters_without_ctx - set (field_descriptions )
123- else :
124- missing_params = set (sig .parameters ) - set (field_descriptions )
125-
126- if missing_params :
127- errors .append (f'Missing parameter descriptions for { ", " .join (missing_params )} ' )
115+ description , field_descriptions = doc_descriptions (original_func , sig , docstring_format = docstring_format )
116+ missing_param_descriptions : set [str ] = set ()
128117
129118 for index , (name , p ) in enumerate (sig .parameters .items ()):
119+ if index == 0 and takes_ctx is None :
120+ takes_ctx = p .annotation is not sig .empty and _is_call_ctx (type_hints [name ])
121+
130122 if p .annotation is sig .empty :
131123 if takes_ctx and index == 0 :
132124 # should be the `context` argument, skip
@@ -148,6 +140,10 @@ def function_schema( # noqa: C901
148140 continue
149141
150142 field_name = p .name
143+
144+ if require_parameter_descriptions and field_name not in field_descriptions :
145+ missing_param_descriptions .add (field_name )
146+
151147 if p .kind == Parameter .VAR_KEYWORD :
152148 var_kwargs_schema = gen_schema .generate_schema (annotation )
153149 else :
@@ -178,6 +174,9 @@ def function_schema( # noqa: C901
178174 elif p .kind == Parameter .VAR_POSITIONAL :
179175 var_positional_field = field_name
180176
177+ if missing_param_descriptions :
178+ errors .append (f'Missing parameter descriptions for { ", " .join (missing_param_descriptions )} ' )
179+
181180 if errors :
182181 from .exceptions import UserError
183182
@@ -219,7 +218,7 @@ def function_schema( # noqa: C901
219218 single_arg_name = single_arg_name ,
220219 positional_fields = positional_fields ,
221220 var_positional_field = var_positional_field ,
222- takes_ctx = takes_ctx ,
221+ takes_ctx = bool ( takes_ctx ) ,
223222 is_async = is_async_callable (function ),
224223 function = function ,
225224 )
@@ -234,7 +233,7 @@ def function_schema( # noqa: C901
234233TargetCallable = WithCtx [P , R ] | WithoutCtx [P , R ]
235234
236235
237- def _takes_ctx (callable_obj : TargetCallable [P , R ]) -> TypeIs [WithCtx [P , R ]]:
236+ def _takes_ctx (callable_obj : TargetCallable [P , R ]) -> TypeIs [WithCtx [P , R ]]: # pyright: ignore[reportUnusedFunction]
238237 """Check if a callable takes a `RunContext` first argument.
239238
240239 Args:
0 commit comments