Skip to content

Commit 4cbf96d

Browse files
committed
Add bool_special_form.
1 parent 7a0df62 commit 4cbf96d

File tree

4 files changed

+140
-5
lines changed

4 files changed

+140
-5
lines changed

tests/test_type_eval.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
StrConcat,
3434
StrSlice,
3535
Uppercase,
36+
bool_special_form,
3637
)
3738

3839
from . import format_helper
@@ -771,9 +772,19 @@ def test_callable_to_signature():
771772
)
772773

773774

774-
type IsNotInt[T] = not Is[T, int]
775-
type IsNotStr[T] = not Is[T, str]
776-
type IsNotIntOrStr[T] = IsNotInt[T] and IsNotStr[T]
775+
@bool_special_form
776+
class IsNotInt[T]:
777+
__expr__ = not Is[T, int]
778+
779+
780+
@bool_special_form
781+
class IsNotStr[T]:
782+
__expr__ = not Is[T, str]
783+
784+
785+
@bool_special_form
786+
class IsNotIntOrStr[T]:
787+
__expr__ = IsNotInt[T] and IsNotStr[T]
777788

778789

779790
type SetOfNotInt[T] = set[T] if IsNotInt[T] else T

typemap/type_eval/_eval_typing.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
from typing import Any
2020

2121
from . import _apply_generic
22-
from ._special_form import _special_form_evaluator
22+
from ._special_form import (
23+
BoolSpecialMetadata,
24+
_bool_special_form_registry,
25+
_special_form_evaluator,
26+
)
2327

2428

2529
__all__ = ("eval_typing",)
@@ -346,13 +350,56 @@ def _eval_applied_type_alias(obj: types.GenericAlias, ctx: EvalContext):
346350
return evaled
347351

348352

353+
def _eval_bool_special_form(
354+
metadata: BoolSpecialMetadata,
355+
new_args: tuple[typing.Any, ...],
356+
ctx: EvalContext,
357+
) -> bool:
358+
import ast
359+
360+
original_cls = metadata.cls
361+
362+
try:
363+
namespace = {}
364+
365+
# Add the class's module
366+
if cls_module := sys.modules.get(original_cls.__module__):
367+
namespace.update(cls_module.__dict__)
368+
369+
# Add type parameters with their substituted values
370+
type_params = metadata.type_params
371+
if type_params and new_args:
372+
for param, arg in zip(type_params, new_args, strict=False):
373+
namespace[param.__name__] = arg
374+
375+
expr = compile(
376+
ast.Expression(body=metadata.expr_node), # type: ignore[arg-type]
377+
'<bool_expr>',
378+
'eval',
379+
)
380+
bool_expr = eval(expr, namespace)
381+
382+
# Evaluate the type expression
383+
result = _eval_types(bool_expr, ctx)
384+
385+
return result
386+
387+
except Exception as e:
388+
raise RuntimeError(
389+
f"Failed to evaluate special form for {original_cls.__name__}: {e}"
390+
) from e
391+
392+
349393
@_eval_types_impl.register
350394
def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext):
351395
"""Eval a typing._GenericAlias -- an applied user-defined class"""
352396
# generic *classes* are typing._GenericAlias while generic type
353397
# aliases are types.GenericAlias? Why in the world.
354398
new_args = tuple(_eval_types(arg, ctx) for arg in typing.get_args(obj))
355399

400+
if metadata := _bool_special_form_registry.get(obj.__origin__):
401+
return _eval_bool_special_form(metadata, new_args, ctx)
402+
356403
if func := _eval_funcs.get(obj.__origin__):
357404
ret = func(*new_args, ctx=ctx)
358405
# return _eval_types(ret, ctx) # ???

typemap/type_eval/_special_form.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import ast
12
import contextvars
3+
import dataclasses
24
import typing
35
from typing import _GenericAlias # type: ignore
46

@@ -21,10 +23,80 @@ def __iter__(self):
2123
return iter(typing.TypeVarTuple("_IterDummy"))
2224

2325

24-
class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
26+
class _BoolGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
2527
def __bool__(self):
2628
evaluator = _special_form_evaluator.get()
2729
if evaluator:
2830
return evaluator(self)
2931
else:
3032
return False
33+
34+
35+
_IsGenericAlias = _BoolGenericAlias
36+
37+
38+
_bool_special_form_registry: dict[typing.Any, BoolSpecialMetadata] = {}
39+
40+
41+
@dataclasses.dataclass(frozen=True, kw_only=True)
42+
class BoolSpecialMetadata:
43+
cls: type
44+
type_params: tuple[type]
45+
expr_node: ast.AST
46+
47+
48+
def _register_bool_special_form(cls):
49+
import inspect
50+
import textwrap
51+
52+
type_params = getattr(cls, '__type_params__', ())
53+
54+
if '__expr__' not in cls.__dict__:
55+
raise TypeError(f"{cls.__name__} must have an '__expr__' field")
56+
57+
# Parse __expr__ to get the assigned expression
58+
source = inspect.getsource(cls)
59+
source = textwrap.dedent(source)
60+
tree = ast.parse(source)
61+
62+
expr_node = None
63+
for node in ast.walk(tree):
64+
if isinstance(node, ast.ClassDef):
65+
for item in node.body:
66+
if isinstance(item, ast.AnnAssign):
67+
# __expr__: SomeType = expression
68+
if (
69+
isinstance(item.target, ast.Name)
70+
and item.target.id == '__expr__'
71+
):
72+
expr_node = item.value
73+
break
74+
elif isinstance(item, ast.Assign):
75+
# __expr__ = expression
76+
for target in item.targets:
77+
if (
78+
isinstance(target, ast.Name)
79+
and target.id == '__expr__'
80+
):
81+
expr_node = item.value
82+
break
83+
if expr_node:
84+
break
85+
if expr_node:
86+
break
87+
88+
if expr_node is None:
89+
raise TypeError(f"Could not find __expr__ assignment in {cls.__name__}")
90+
91+
def impl_func(self, params):
92+
return _BoolGenericAlias(self, params)
93+
94+
sf = _SpecialForm(impl_func)
95+
96+
_bool_special_form_registry[sf] = BoolSpecialMetadata(
97+
cls=cls,
98+
type_params=type_params,
99+
expr_node=expr_node,
100+
)
101+
102+
return sf

typemap/typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
_IterGenericAlias,
55
_IsGenericAlias,
66
_SpecialForm,
7+
_register_bool_special_form,
78
)
89

910
# Not type-level computation but related
@@ -132,3 +133,7 @@ def IsSubSimilar(self, tps):
132133

133134

134135
Is = IsSubSimilar
136+
137+
138+
def bool_special_form(cls):
139+
return _register_bool_special_form(cls)

0 commit comments

Comments
 (0)