diff --git a/tests/test_type_dir.py b/tests/test_type_dir.py index 08da787..0ef7ff9 100644 --- a/tests/test_type_dir.py +++ b/tests/test_type_dir.py @@ -206,7 +206,7 @@ class Final: def foo(self: Self, a: int | None, *, b: int = ...) -> dict[str, int]: ... def base[Z](self: Self, a: int | Z | None, b: ~K) -> dict[str, int | Z]: ... @classmethod - def cbase(cls: type[tests.test_type_dir.Base[int]], a: int | None, b: ~K) -> dict[str, int]: ... + def cbase(cls: type[typing.Self], a: int | None, b: ~K) -> dict[str, int]: ... @staticmethod def sbase[Z](a: int | Literal['gotcha!'] | Z | None, b: ~K) -> dict[str, int | Z]: ... """) @@ -218,7 +218,7 @@ def test_type_dir_1b(): assert format_helper.format_class(d) == textwrap.dedent("""\ class CMethod: @classmethod - def cbase2(cls: type[tests.test_type_dir.CMethod], lol: int, /, a: bool | None) -> int: ... + def cbase2(cls: type[typing.Self], lol: int, /, a: bool | None) -> int: ... """) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index f7989ed..30beb31 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -535,17 +535,17 @@ class C: def f(cls, x: int, /, y: int, *, z: int) -> int: ... f = eval_typing(GetMethodLike[IndirectProtocol[C], Literal["f"]]) - t = eval_typing(GetArg[f, Callable, Literal[0]]) + t = eval_typing(GetArg[f, classmethod, Literal[0]]) + t = eval_typing(GetArg[f, classmethod, Literal[1]]) assert ( t == tuple[ - Param[Literal["cls"], type[C], Literal["positional"]], Param[Literal["x"], int, Literal["positional"]], Param[Literal["y"], int], Param[Literal["z"], int, Literal["keyword"]], ] ) - t = eval_typing(GetArg[f, Callable, Literal[1]]) + t = eval_typing(GetArg[f, classmethod, Literal[2]]) assert t is int @@ -574,7 +574,7 @@ class C: def f(x: int, /, y: int, *, z: int) -> int: ... f = eval_typing(GetMethodLike[IndirectProtocol[C], Literal["f"]]) - t = eval_typing(GetArg[f, Callable, Literal[0]]) + t = eval_typing(GetArg[f, staticmethod, Literal[0]]) assert ( t == tuple[ @@ -583,7 +583,7 @@ def f(x: int, /, y: int, *, z: int) -> int: ... Param[Literal["z"], int, Literal["keyword"]], ] ) - t = eval_typing(GetArg[f, Callable, Literal[1]]) + t = eval_typing(GetArg[f, staticmethod, Literal[1]]) assert t is int @@ -957,9 +957,8 @@ def test_is_literal_true_vs_one(): assert eval_typing(IsSub[Literal[True], Literal[1]]) is False -def test_callable_to_signature(): +def test_callable_to_signature_01(): from typemap.type_eval._eval_operators import _callable_type_to_signature - from typemap.typing import Param # Test the example from the docstring: # def func( @@ -996,6 +995,111 @@ def test_callable_to_signature(): ) +def test_callable_to_signature_02(): + from typemap.type_eval._eval_operators import _callable_type_to_signature + + class C: + pass + + callable_type = classmethod[ + C, + tuple[ + Param[None, int], + Param[Literal["b"], int], + Param[Literal["c"], int, Literal["default"]], + Param[None, int, Literal["*"]], + Param[Literal["d"], int, Literal["keyword"]], + Param[Literal["e"], int, Literal["default", "keyword"]], + Param[None, int, Literal["**"]], + ], + int, + ] + sig = _callable_type_to_signature(callable_type) + assert str(sig) == ( + '(cls: tests.test_type_eval.test_callable_to_signature_02..C, ' + '_arg1: int, /, b: int, c: int = ..., *args: int, ' + 'd: int, e: int = ..., **kwargs: int) -> int' + ) + + +def test_callable_to_signature_03(): + from typemap.type_eval._eval_operators import _callable_type_to_signature + + class C: + pass + + callable_type = staticmethod[ + tuple[ + Param[None, int], + Param[Literal["b"], int], + Param[Literal["c"], int, Literal["default"]], + Param[None, int, Literal["*"]], + Param[Literal["d"], int, Literal["keyword"]], + Param[Literal["e"], int, Literal["default", "keyword"]], + Param[None, int, Literal["**"]], + ], + int, + ] + sig = _callable_type_to_signature(callable_type) + assert str(sig) == ( + '(_arg0: int, /, b: int, c: int = ..., *args: int, ' + 'd: int, e: int = ..., **kwargs: int) -> int' + ) + + +def test_new_protocol_with_methods_01(): + class C: + def member_method(self, x: int) -> int: ... + @classmethod + def class_method(cls, x: int) -> int: ... + @staticmethod + def static_method(x: int) -> int: ... + + res = eval_typing(IndirectProtocol[C]) + fmt = format_helper.format_class(res) + assert fmt == textwrap.dedent("""\ + class IndirectProtocol[tests.test_type_eval.test_new_protocol_with_methods_01..C]: + def member_method(self: Self, x: int) -> int: ... + @classmethod + def class_method(cls: type[typing.Self], x: int) -> int: ... + @staticmethod + def static_method(x: int) -> int: ... + """) + + +def test_new_protocol_with_methods_02(): + C = NewProtocol[ + Member[ + Literal["member_method"], + Callable[ + [Param[Literal["self"], Self], Param[Literal["x"], int]], int + ], + Literal["ClassVar"], + ], + Member[ + Literal["class_method"], + classmethod[type[Self], tuple[Param[Literal["x"], int]], int], + Literal["ClassVar"], + ], + Member[ + Literal["static_method"], + staticmethod[tuple[Param[Literal["x"], int]], int], + Literal["ClassVar"], + ], + ] + + res = eval_typing(IndirectProtocol[C]) + fmt = format_helper.format_class(res) + assert fmt == textwrap.dedent("""\ + class IndirectProtocol[typemap.type_eval._eval_operators.NewProtocol]: + def member_method(self: Self, x: int) -> int: ... + @classmethod + def class_method(cls: type[typing.Self], x: int) -> int: ... + @staticmethod + def static_method(x: int) -> int: ... + """) + + ############## type XTest[X] = Annotated[X, 'blah'] diff --git a/typemap/type_eval/_apply_generic.py b/typemap/type_eval/_apply_generic.py index 2f2eb4d..15f6f09 100644 --- a/typemap/type_eval/_apply_generic.py +++ b/typemap/type_eval/_apply_generic.py @@ -250,6 +250,10 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]: stuff = inspect.unwrap(orig) if isinstance(stuff, types.FunctionType): + local_fn: types.FunctionType | classmethod | staticmethod | None = ( + None + ) + if af := typing.cast( types.FunctionType, getattr(stuff, "__annotate__", None) ): @@ -280,9 +284,17 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]: ) rr = ff(annotationlib.Format.VALUE) - dct[name] = make_func(orig, rr) + local_fn = make_func(orig, rr) elif af := getattr(stuff, "__annotations__", None): - dct[name] = stuff + local_fn = stuff + + if local_fn is not None: + if orig.__class__ is classmethod: + local_fn = classmethod(local_fn) + elif orig.__class__ is staticmethod: + local_fn = staticmethod(local_fn) + + dct[name] = local_fn return annos, dct diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 3842b2d..cb6d2db 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -280,10 +280,46 @@ def _callable_type_to_signature(callable_type: object) -> inspect.Signature: or Never if no qualifiers """ args = typing.get_args(callable_type) - if len(args) != 2: - raise TypeError(f"Expected Callable[[...], ret], got {callable_type}") + if ( + isinstance(callable_type, types.GenericAlias) + and callable_type.__origin__ is classmethod + ): + if len(args) != 3: + raise TypeError( + f"Expected classmethod[cls, [...], ret], got {callable_type}" + ) + + receiver, param_types, return_type = typing.get_args(callable_type) + param_types = [ + Param[ + typing.Literal["cls"], + receiver, # type: ignore[valid-type] + typing.Literal["positional"], + ], + *param_types.__args__, + ] + + elif ( + isinstance(callable_type, types.GenericAlias) + and callable_type.__origin__ is staticmethod + ): + if len(args) != 2: + raise TypeError( + f"Expected staticmethod[...], ret], got {callable_type}" + ) + + param_types, return_type = typing.get_args(callable_type) + param_types = [ + *param_types.__args__, + ] + + else: + if len(args) != 2: + raise TypeError( + f"Expected Callable[[...], ret], got {callable_type}" + ) - param_types, return_type = args + param_types, return_type = args # Handle the case where param_types is a list of Param types if not isinstance(param_types, (list, tuple)): @@ -421,7 +457,11 @@ def _callable_type_to_method(name, typ): # positional only argument. Annoying! has_pos_only = any(_is_pos_only(p) for p in typing.get_args(params)) quals = typing.Literal["positional"] if has_pos_only else typing.Never - cls_param = Param[typing.Literal["cls"], type[cls], quals] + # Override the receiver type with type[Self]. + # An annoying thing to know is that for a member classmethod of C, + # cls *should* be type[C], but if it was not explicitly annotated, it + # will be C. + cls_param = Param[typing.Literal["cls"], type[typing.Self], quals] typ = typing.Callable[[cls_param] + list(typing.get_args(params)), ret] elif head is staticmethod: params, ret = typing.get_args(typ)