diff --git a/mypyc/doc/native_operations.rst b/mypyc/doc/native_operations.rst index 3255dbedd98a..356a8930d656 100644 --- a/mypyc/doc/native_operations.rst +++ b/mypyc/doc/native_operations.rst @@ -54,3 +54,4 @@ These variants of statements have custom implementations: * ``for ... in seq:`` (for loop over a sequence) * ``for ... in enumerate(...):`` * ``for ... in zip(...):`` +* ``for ... in map(...)`` diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 762b41866a05..25fb148eddaf 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -11,6 +11,7 @@ from mypy.nodes import ( ARG_POS, + LDEF, CallExpr, DictionaryComprehension, Expression, @@ -22,6 +23,7 @@ SetExpr, TupleExpr, TypeAlias, + Var, ) from mypyc.ir.ops import ( ERR_NEVER, @@ -491,6 +493,16 @@ def make_for_loop_generator( for_list = ForSequence(builder, index, body_block, loop_exit, line, nested) for_list.init(expr_reg, target_type, reverse=True) return for_list + + elif ( + expr.callee.fullname == "builtins.map" + and len(expr.args) >= 2 + and all(k == ARG_POS for k in expr.arg_kinds) + ): + for_map = ForMap(builder, index, body_block, loop_exit, line, nested) + for_map.init(expr.args[0], expr.args[1:]) + return for_map + if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberExpr) and not expr.args: # Special cases for dictionary iterator methods, like dict.items(). rtype = builder.node_type(expr.callee.expr) @@ -1166,3 +1178,72 @@ def gen_step(self) -> None: def gen_cleanup(self) -> None: for gen in self.gens: gen.gen_cleanup() + + +class ForMap(ForGenerator): + """Generate optimized IR for a for loop over map(f, ...).""" + + def need_cleanup(self) -> bool: + # The wrapped for loops might need cleanup. We might generate a + # redundant cleanup block, but that's okay. + return True + + def init(self, func: Expression, exprs: list[Expression]) -> None: + self.func_expr = func + self.func = self.builder.accept(func) + self.exprs = exprs + self.cond_blocks = [BasicBlock() for _ in range(len(exprs) - 1)] + [self.body_block] + + self.gens: list[ForGenerator] = [] + for i, iterable_expr in enumerate(exprs): + argname = f"_mypyc_map_arg_{i}" + var_type = self.builder._analyze_iterable_item_type(iterable_expr) + name_expr = NameExpr(argname) + name_expr.kind = LDEF + name_expr.node = Var(argname, var_type) + self.builder.add_local_reg(name_expr.node, self.builder.type_to_rtype(var_type)) + self.gens.append( + make_for_loop_generator( + self.builder, + name_expr, + iterable_expr, + self.cond_blocks[i], + self.loop_exit, + self.line, + is_async=False, + nested=True, + ) + ) + + def gen_condition(self) -> None: + for i, gen in enumerate(self.gens): + gen.gen_condition() + if i < len(self.gens) - 1: + self.builder.activate_block(self.cond_blocks[i]) + + def begin_body(self) -> None: + builder = self.builder + + for gen in self.gens: + gen.begin_body() + + call_expr = CallExpr( + self.func_expr, + [gen.index for gen in self.gens], + [ARG_POS] * len(self.gens), + [None] * len(self.gens), + ) + + # TODO: debug redundant box->unbox op in builder.accept and then replace this + from mypyc.irbuild.expression import transform_call_expr + + result = transform_call_expr(builder, call_expr) + builder.assign(builder.get_assignment_target(self.index), result, self.line) + + def gen_step(self) -> None: + for gen in self.gens: + gen.gen_step() + + def gen_cleanup(self) -> None: + for gen in self.gens: + gen.gen_cleanup() diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 0880c62bc7a5..ea36d73466c2 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -75,6 +75,7 @@ from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.for_helpers import ( comprehension_helper, + for_loop_helper, sequence_from_generator_preallocate_helper, translate_list_comprehension, translate_set_comprehension, @@ -95,7 +96,7 @@ ) from mypyc.primitives.float_ops import isinstance_float from mypyc.primitives.int_ops import isinstance_int -from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op +from mypyc.primitives.list_ops import isinstance_list, list_append_op, new_list_set_item_op from mypyc.primitives.misc_ops import isinstance_bool from mypyc.primitives.set_ops import isinstance_frozenset, isinstance_set from mypyc.primitives.str_ops import ( @@ -262,7 +263,7 @@ def dict_methods_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) @specialize_function("builtins.list") -def translate_list_from_generator_call( +def translate_list_from_generator_expr( builder: IRBuilder, expr: CallExpr, callee: RefExpr ) -> Value | None: """Special case for simplest list comprehension. @@ -286,6 +287,50 @@ def translate_list_from_generator_call( return None +@specialize_function("builtins.list") +def translate_list_from_generator_call( + builder: IRBuilder, expr: CallExpr, callee: RefExpr +) -> Value | None: + """Special case for simplest list construction using one of our for_helpers. + + For example: + list(map(f, some_list/some_tuple/some_str)) + """ + if ( + len(expr.args) == 1 + and expr.arg_kinds[0] == ARG_POS + and isinstance(expr.args[0], CallExpr) + and isinstance(expr.args[0].callee, NameExpr) + and expr.args[0].callee.fullname + in ( + # TODO: make constant for these vals + "builtins.map", + "builtins.filter", + "builtins.filterfalse", + ) + ): + call_expr = expr.args[0] + itemtype = builder._analyze_iterable_item_type(call_expr) + indextype = builder.type_to_rtype(itemtype) + index = Register(indextype, "__mypyc_list_helper__", line=call_expr.line) + + result = builder.new_list_op([], expr.line) + + def body_insts() -> None: + builder.primitive_op(list_append_op, [result, index], expr.line) + + for_loop_helper( + builder=builder, + index=index, + expr=call_expr, + body_insts=body_insts, + else_insts=None, + is_async=False, + line=expr.line, + ) + return None + + @specialize_function("builtins.tuple") def translate_tuple_from_generator_call( builder: IRBuilder, expr: CallExpr, callee: RefExpr diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index c041c661741c..d0881e58c66d 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -7,6 +7,8 @@ overload, Mapping, Union, Callable, Sequence, FrozenSet, Protocol ) +from typing_extensions import Self + _T = TypeVar('_T') T_co = TypeVar('T_co', covariant=True) T_contra = TypeVar('T_contra', contravariant=True) @@ -407,3 +409,17 @@ class classmethod: pass class staticmethod: pass NotImplemented: Any = ... + +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") + +class map(Generic[_S]): + @overload + def __new__(cls, func: Callable[[_T1], _S], iterable: Iterable[_T1], /) -> Self: ... + @overload + def __new__(cls, func: Callable[[_T1, _T2], _S], iterable: Iterable[_T1], iter2: Iterable[_T2], /) -> Self: ... + @overload + def __new__(cls, func: Callable[[_T1, _T2, _T3], _S], iterable: Iterable[_T1], iter2: Iterable[_T2], iter3: Iterable[_T3], /) -> Self: ... + def __iter__(self) -> Self: ... + def __next__(self) -> _S: ... diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index f52e1af03b52..12506faed4aa 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -3541,6 +3541,220 @@ L0: r3 = box(None, 1) return r3 +[case testForMapBasic] +def f(x: int) -> int: + return x * 2 +def g(a: list[int]) -> int: + s = 0 + for x in map(f, a): + s += x + return s +[out] +def f(x): + x, r0 :: int +L0: + r0 = CPyTagged_Multiply(x, 4) + return r0 +def g(a): + a :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4 :: native_int + r5 :: bit + r6 :: object + r7, _mypyc_map_arg_0 :: int + r8 :: object + r9, r10, x, r11 :: int + r12 :: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 +L1: + r4 = var_object_size a + r5 = r3 < r4 :: signed + if r5 goto L2 else goto L4 :: bool +L2: + r6 = list_get_item_unsafe a, r3 + r7 = unbox(int, r6) + _mypyc_map_arg_0 = r7 + r8 = box(int, _mypyc_map_arg_0) + r9 = unbox(int, r8) + r10 = f(r9) + x = r10 + r11 = CPyTagged_Add(s, x) + s = r11 +L3: + r12 = r3 + 1 + r3 = r12 + goto L1 +L4: +L5: + return s + +[case testForMapComplex] +def f(x: int, y: int, z: int) -> int: + return x + y + z +def g(a: list[int], b: list[int], c: list[int]) -> int: + s = 0 + for x in map(f, a, b, c): + s += x + return s +[out] +def f(x, y, z): + x, y, z, r0, r1 :: int +L0: + r0 = CPyTagged_Add(x, y) + r1 = CPyTagged_Add(r0, z) + return r1 +def g(a, b, c): + a, b, c :: list + s :: int + r0 :: dict + r1 :: str + r2 :: object + r3, r4, r5, r6 :: native_int + r7 :: bit + r8 :: native_int + r9 :: bit + r10 :: native_int + r11 :: bit + r12 :: object + r13, _mypyc_map_arg_0 :: int + r14 :: object + r15, _mypyc_map_arg_1 :: int + r16 :: object + r17, _mypyc_map_arg_2 :: int + r18, r19, r20 :: object + r21, r22, r23, r24, x, r25 :: int + r26, r27, r28 :: native_int +L0: + s = 0 + r0 = __main__.globals :: static + r1 = 'f' + r2 = CPyDict_GetItem(r0, r1) + r3 = 0 + r4 = 0 + r5 = 0 +L1: + r6 = var_object_size a + r7 = r3 < r6 :: signed + if r7 goto L2 else goto L6 :: bool +L2: + r8 = var_object_size b + r9 = r4 < r8 :: signed + if r9 goto L3 else goto L6 :: bool +L3: + r10 = var_object_size c + r11 = r5 < r10 :: signed + if r11 goto L4 else goto L6 :: bool +L4: + r12 = list_get_item_unsafe a, r3 + r13 = unbox(int, r12) + _mypyc_map_arg_0 = r13 + r14 = list_get_item_unsafe b, r4 + r15 = unbox(int, r14) + _mypyc_map_arg_1 = r15 + r16 = list_get_item_unsafe c, r5 + r17 = unbox(int, r16) + _mypyc_map_arg_2 = r17 + r18 = box(int, _mypyc_map_arg_0) + r19 = box(int, _mypyc_map_arg_1) + r20 = box(int, _mypyc_map_arg_2) + r21 = unbox(int, r18) + r22 = unbox(int, r19) + r23 = unbox(int, r20) + r24 = f(r21, r22, r23) + x = r24 + r25 = CPyTagged_Add(s, x) + s = r25 +L5: + r26 = r3 + 1 + r3 = r26 + r27 = r4 + 1 + r4 = r27 + r28 = r5 + 1 + r5 = r28 + goto L1 +L6: +L7: + return s + +[case testForMapComprehension] +def f(x: int, y: int) -> int: + return x * y +def g(a: list[int], b: list[int]) -> list[int]: + return [x for x in map(f, a, b)] +[out] +def f(x, y): + x, y, r0 :: int +L0: + r0 = CPyTagged_Multiply(x, y) + return r0 +def g(a, b): + a, b, r0 :: list + r1 :: dict + r2 :: str + r3 :: object + r4, r5, r6 :: native_int + r7 :: bit + r8 :: native_int + r9 :: bit + r10 :: object + r11, _mypyc_map_arg_0 :: int + r12 :: object + r13, _mypyc_map_arg_1 :: int + r14, r15 :: object + r16, r17, r18, x :: int + r19 :: object + r20 :: i32 + r21 :: bit + r22, r23 :: native_int +L0: + r0 = PyList_New(0) + r1 = __main__.globals :: static + r2 = 'f' + r3 = CPyDict_GetItem(r1, r2) + r4 = 0 + r5 = 0 +L1: + r6 = var_object_size a + r7 = r4 < r6 :: signed + if r7 goto L2 else goto L5 :: bool +L2: + r8 = var_object_size b + r9 = r5 < r8 :: signed + if r9 goto L3 else goto L5 :: bool +L3: + r10 = list_get_item_unsafe a, r4 + r11 = unbox(int, r10) + _mypyc_map_arg_0 = r11 + r12 = list_get_item_unsafe b, r5 + r13 = unbox(int, r12) + _mypyc_map_arg_1 = r13 + r14 = box(int, _mypyc_map_arg_0) + r15 = box(int, _mypyc_map_arg_1) + r16 = unbox(int, r14) + r17 = unbox(int, r15) + r18 = f(r16, r17) + x = r18 + r19 = box(int, x) + r20 = PyList_Append(r0, r19) + r21 = r20 >= 0 :: signed +L4: + r22 = r4 + 1 + r4 = r22 + r23 = r5 + 1 + r5 = r23 + goto L1 +L5: +L6: + return r0 + [case testStarArgFastPathTuple] from typing import Any, Callable def deco(fn: Callable[..., Any]) -> Callable[..., Any]: diff --git a/mypyc/test-data/run-loops.test b/mypyc/test-data/run-loops.test index 3cbb07297e6e..d0ad970438f1 100644 --- a/mypyc/test-data/run-loops.test +++ b/mypyc/test-data/run-loops.test @@ -571,3 +571,56 @@ print([x for x in native.Vector2(4, -5.2)]) [out] Vector2(x=-2, y=3.1) \[4, -5.2] + +[case testRunForMap] +def single(a: list[int]) -> int: + s = 0 + for x in map(lambda x: x + 1, a): + s += x + return s +def double(a: list[int], b: list[int]) -> int: + s = 0 + for x in map(lambda x, y: x * y, a, b): + s += x + return s +def triple(a: list[int], b: list[int], c: list[int]) -> int: + s = 0 + for x in map(lambda x, y, z: x * y + z, a, b, c): + s += x + return s +def nested(a: list[int], b: list[int], c: list[int]) -> int: + s = 0 + for x in map(lambda x, y: x + y, map(lambda x, y: x * y, a, b), c): + s += x + return s +def unpack(a: list[int], b: list[int]) -> int: + s = 0 + for x, y in map(lambda x, y: (x, y), a, b): + s += x * 10 + y + return s + +def test_single() -> None: + assert single([1, 2, 3]) == 9 +def test_double() -> None: + assert double([1, 2, 3], [4, 5, 6]) == 32 +def test_double_uneven() -> None: + # Shortest wins: only 2 items + assert double([1, 2], [10, 20, 30]) == 50 +def test_double_empty() -> None: + assert double([], []) == 0 +def test_double_empty_first() -> None: + assert double([], [1, 2]) == 0 +def test_double_empty_second() -> None: + assert double([1, 2], []) == 0 +def test_triple() -> None: + assert triple([1, 2, 3], [4, 5, 6], [7, 8, 9]) == 56 +def test_triple_uneven() -> None: + assert triple([1, 2], [10, 20, 30], [100, 200, 300]) == 350 +def test_triple_empty() -> None: + assert triple([], [], []) == 0 +def test_triple_one_empty() -> None: + assert triple([1, 2], [], [3, 4]) == 0 +def test_nested() -> None: + assert nested([1, 2, 3], [4, 5, 6], [7, 8, 9]) == 56 +def test_unpack() -> None: + assert unpack([1, 2, 3], [4, 5, 6]) == 75