Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_type_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class Final:
def foo(self: Self, a: int | None, *, b: int = ...) -> dict[str, int]: ...
def base[Z](self: Self, a: int | Z | None, b: ~K) -> dict[str, int | Z]: ...
@classmethod
def cbase(cls: type[tests.test_type_dir.Base[int]], a: int | None, b: ~K) -> dict[str, int]: ...
def cbase(cls: type[typing.Self], a: int | None, b: ~K) -> dict[str, int]: ...
@staticmethod
def sbase[Z](a: int | Literal['gotcha!'] | Z | None, b: ~K) -> dict[str, int | Z]: ...
""")
Expand All @@ -218,7 +218,7 @@ def test_type_dir_1b():
assert format_helper.format_class(d) == textwrap.dedent("""\
class CMethod:
@classmethod
def cbase2(cls: type[tests.test_type_dir.CMethod], lol: int, /, a: bool | None) -> int: ...
def cbase2(cls: type[typing.Self], lol: int, /, a: bool | None) -> int: ...
""")


Expand Down
118 changes: 111 additions & 7 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,17 +535,17 @@ class C:
def f(cls, x: int, /, y: int, *, z: int) -> int: ...

f = eval_typing(GetMethodLike[IndirectProtocol[C], Literal["f"]])
t = eval_typing(GetArg[f, Callable, Literal[0]])
t = eval_typing(GetArg[f, classmethod, Literal[0]])
t = eval_typing(GetArg[f, classmethod, Literal[1]])
assert (
t
== tuple[
Param[Literal["cls"], type[C], Literal["positional"]],
Param[Literal["x"], int, Literal["positional"]],
Param[Literal["y"], int],
Param[Literal["z"], int, Literal["keyword"]],
]
)
t = eval_typing(GetArg[f, Callable, Literal[1]])
t = eval_typing(GetArg[f, classmethod, Literal[2]])
assert t is int


Expand Down Expand Up @@ -574,7 +574,7 @@ class C:
def f(x: int, /, y: int, *, z: int) -> int: ...

f = eval_typing(GetMethodLike[IndirectProtocol[C], Literal["f"]])
t = eval_typing(GetArg[f, Callable, Literal[0]])
t = eval_typing(GetArg[f, staticmethod, Literal[0]])
assert (
t
== tuple[
Expand All @@ -583,7 +583,7 @@ def f(x: int, /, y: int, *, z: int) -> int: ...
Param[Literal["z"], int, Literal["keyword"]],
]
)
t = eval_typing(GetArg[f, Callable, Literal[1]])
t = eval_typing(GetArg[f, staticmethod, Literal[1]])
assert t is int


Expand Down Expand Up @@ -957,9 +957,8 @@ def test_is_literal_true_vs_one():
assert eval_typing(IsSub[Literal[True], Literal[1]]) is False


def test_callable_to_signature():
def test_callable_to_signature_01():
from typemap.type_eval._eval_operators import _callable_type_to_signature
from typemap.typing import Param

# Test the example from the docstring:
# def func(
Expand Down Expand Up @@ -996,6 +995,111 @@ def test_callable_to_signature():
)


def test_callable_to_signature_02():
from typemap.type_eval._eval_operators import _callable_type_to_signature

class C:
pass

callable_type = classmethod[
C,
tuple[
Param[None, int],
Param[Literal["b"], int],
Param[Literal["c"], int, Literal["default"]],
Param[None, int, Literal["*"]],
Param[Literal["d"], int, Literal["keyword"]],
Param[Literal["e"], int, Literal["default", "keyword"]],
Param[None, int, Literal["**"]],
],
int,
]
sig = _callable_type_to_signature(callable_type)
assert str(sig) == (
'(cls: tests.test_type_eval.test_callable_to_signature_02.<locals>.C, '
'_arg1: int, /, b: int, c: int = ..., *args: int, '
'd: int, e: int = ..., **kwargs: int) -> int'
)


def test_callable_to_signature_03():
from typemap.type_eval._eval_operators import _callable_type_to_signature

class C:
pass

callable_type = staticmethod[
tuple[
Param[None, int],
Param[Literal["b"], int],
Param[Literal["c"], int, Literal["default"]],
Param[None, int, Literal["*"]],
Param[Literal["d"], int, Literal["keyword"]],
Param[Literal["e"], int, Literal["default", "keyword"]],
Param[None, int, Literal["**"]],
],
int,
]
sig = _callable_type_to_signature(callable_type)
assert str(sig) == (
'(_arg0: int, /, b: int, c: int = ..., *args: int, '
'd: int, e: int = ..., **kwargs: int) -> int'
)


def test_new_protocol_with_methods_01():
class C:
def member_method(self, x: int) -> int: ...
@classmethod
def class_method(cls, x: int) -> int: ...
@staticmethod
def static_method(x: int) -> int: ...

res = eval_typing(IndirectProtocol[C])
fmt = format_helper.format_class(res)
assert fmt == textwrap.dedent("""\
class IndirectProtocol[tests.test_type_eval.test_new_protocol_with_methods_01.<locals>.C]:
def member_method(self: Self, x: int) -> int: ...
@classmethod
def class_method(cls: type[typing.Self], x: int) -> int: ...
@staticmethod
def static_method(x: int) -> int: ...
""")


def test_new_protocol_with_methods_02():
C = NewProtocol[
Member[
Literal["member_method"],
Callable[
[Param[Literal["self"], Self], Param[Literal["x"], int]], int
],
Literal["ClassVar"],
],
Member[
Literal["class_method"],
classmethod[type[Self], tuple[Param[Literal["x"], int]], int],
Literal["ClassVar"],
],
Member[
Literal["static_method"],
staticmethod[tuple[Param[Literal["x"], int]], int],
Literal["ClassVar"],
],
]

res = eval_typing(IndirectProtocol[C])
fmt = format_helper.format_class(res)
assert fmt == textwrap.dedent("""\
class IndirectProtocol[typemap.type_eval._eval_operators.NewProtocol]:
def member_method(self: Self, x: int) -> int: ...
@classmethod
def class_method(cls: type[typing.Self], x: int) -> int: ...
@staticmethod
def static_method(x: int) -> int: ...
""")


##############

type XTest[X] = Annotated[X, 'blah']
Expand Down
16 changes: 14 additions & 2 deletions typemap/type_eval/_apply_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
stuff = inspect.unwrap(orig)

if isinstance(stuff, types.FunctionType):
local_fn: types.FunctionType | classmethod | staticmethod | None = (
None
)

if af := typing.cast(
types.FunctionType, getattr(stuff, "__annotate__", None)
):
Expand Down Expand Up @@ -280,9 +284,17 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
)
rr = ff(annotationlib.Format.VALUE)

dct[name] = make_func(orig, rr)
local_fn = make_func(orig, rr)
elif af := getattr(stuff, "__annotations__", None):
dct[name] = stuff
local_fn = stuff

if local_fn is not None:
if orig.__class__ is classmethod:
local_fn = classmethod(local_fn)
elif orig.__class__ is staticmethod:
local_fn = staticmethod(local_fn)

dct[name] = local_fn

return annos, dct

Expand Down
48 changes: 44 additions & 4 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,46 @@ def _callable_type_to_signature(callable_type: object) -> inspect.Signature:
or Never if no qualifiers
"""
args = typing.get_args(callable_type)
if len(args) != 2:
raise TypeError(f"Expected Callable[[...], ret], got {callable_type}")
if (
isinstance(callable_type, types.GenericAlias)
and callable_type.__origin__ is classmethod
):
if len(args) != 3:
raise TypeError(
f"Expected classmethod[cls, [...], ret], got {callable_type}"
)

receiver, param_types, return_type = typing.get_args(callable_type)
param_types = [
Param[
typing.Literal["cls"],
receiver, # type: ignore[valid-type]
typing.Literal["positional"],
],
*param_types.__args__,
]

elif (
isinstance(callable_type, types.GenericAlias)
and callable_type.__origin__ is staticmethod
):
if len(args) != 2:
raise TypeError(
f"Expected staticmethod[...], ret], got {callable_type}"
)

param_types, return_type = typing.get_args(callable_type)
param_types = [
*param_types.__args__,
]

else:
if len(args) != 2:
raise TypeError(
f"Expected Callable[[...], ret], got {callable_type}"
)

param_types, return_type = args
param_types, return_type = args

# Handle the case where param_types is a list of Param types
if not isinstance(param_types, (list, tuple)):
Expand Down Expand Up @@ -421,7 +457,11 @@ def _callable_type_to_method(name, typ):
# positional only argument. Annoying!
has_pos_only = any(_is_pos_only(p) for p in typing.get_args(params))
quals = typing.Literal["positional"] if has_pos_only else typing.Never
cls_param = Param[typing.Literal["cls"], type[cls], quals]
# Override the receiver type with type[Self].
# An annoying thing to know is that for a member classmethod of C,
# cls *should* be type[C], but if it was not explicitly annotated, it
# will be C.
cls_param = Param[typing.Literal["cls"], type[typing.Self], quals]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we want to override with typing.Self instead of using what was already there?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type[Self] makes the annotations simpler and makes things easier when combining protocols which may have different receivers. We can use GetDefiner if we really need the actual class.

typ = typing.Callable[[cls_param] + list(typing.get_args(params)), ret]
elif head is staticmethod:
params, ret = typing.get_args(typ)
Expand Down