diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 0a7f67c3f2..0a8fd7b59f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -81,16 +81,16 @@ def _is_collectable_expr(node: itir.Node) -> bool: if isinstance(node, itir.FunCall): # do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be # visited, to ensure symbol dependencies are recognized correctly. - # do also not collect reduce nodes if they are left in the IR at this point, this may lead to + # do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend - # backend (single pass eager depth first visit approach) + # backend (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795 # do also not collect lifts or applied lifts as they become invisible to the lift inliner # otherwise # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement # instead of an as_fieldop if cpm.is_call_to( - node, ("lift", "shift", "reduce", "map_", "index") + node, ("lift", "shift", "neighbors", "reduce", "map_", "index") ) or cpm.is_applied_lift(node): return False return True diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index ddbef679d8..ffad69c921 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -24,6 +24,7 @@ misc as ir_misc, ) from gt4py.next.iterator.transforms import ( + cse, fixed_point_transformation, inline_center_deref_lift_vars, inline_lambdas, @@ -119,7 +120,12 @@ def _prettify_as_fieldop_args( def fuse_as_fieldop( - expr: itir.Expr, eligible_args: list[bool], *, uids: utils.IDGeneratorPool + expr: itir.Expr, + eligible_args: list[bool], + *, + offset_provider_type: common.OffsetProviderType, + enable_cse: bool, + uids: utils.IDGeneratorPool, ) -> itir.Expr: assert cpm.is_applied_as_fieldop(expr) @@ -183,6 +189,11 @@ def fuse_as_fieldop( new_stencil = inline_lifts.InlineLifts().visit(new_stencil) new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) + if enable_cse: + # TODO(havogt): We should investigate how to keep the tree small without having to run CSE. + new_node = cse.CommonSubexpressionElimination.apply( + new_node, within_stencil=False, uids=uids, offset_provider_type=offset_provider_type + ) return new_node @@ -281,6 +292,8 @@ def all(self) -> FuseAsFieldOp.Transformation: enabled_transformations = Transformation.all() uids: utils.IDGeneratorPool + offset_provider_type: common.OffsetProviderType + enable_cse: bool # option to disable is mainly for testing purposes @classmethod def apply( @@ -292,6 +305,7 @@ def apply( allow_undeclared_symbols=False, within_set_at_expr: Optional[bool] = None, enabled_transformations: Optional[Transformation] = None, + enable_cse: bool = True, ): enabled_transformations = enabled_transformations or cls.enabled_transformations @@ -304,9 +318,12 @@ def apply( if within_set_at_expr is None: within_set_at_expr = not isinstance(node, itir.Program) - new_node = cls(uids=uids, enabled_transformations=enabled_transformations).visit( - node, within_set_at_expr=within_set_at_expr - ) + new_node = cls( + uids=uids, + enabled_transformations=enabled_transformations, + offset_provider_type=offset_provider_type, + enable_cse=enable_cse, + ).visit(node, within_set_at_expr=within_set_at_expr) # The `FuseAsFieldOp` pass does not fully preserve the type information yet. In particular # for the generated lifts this is tricky and error-prone. For simplicity, we just reinfer # everything here ensuring later passes can use the information. @@ -391,7 +408,13 @@ def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): ] if any(eligible_els): return self.visit( - fuse_as_fieldop(node, eligible_els, uids=self.uids), + fuse_as_fieldop( + node, + eligible_els, + uids=self.uids, + offset_provider_type=self.offset_provider_type, + enable_cse=self.enable_cse, + ), **{**kwargs, "recurse": False}, ) return None diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py index 22d7d790bc..237dfb769a 100644 --- a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -10,7 +10,7 @@ import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm from gt4py import eve -from gt4py.next import utils +from gt4py.next import common, utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs @@ -33,11 +33,17 @@ def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: @dataclasses.dataclass class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + offset_provider_type: common.OffsetProviderType uids: utils.IDGeneratorPool @classmethod - def apply(cls, node: itir.Program, uids: utils.IDGeneratorPool): - return cls(uids=uids).visit(node) + def apply( + cls, + node: itir.Program, + offset_provider_type: common.OffsetProviderType, + uids: utils.IDGeneratorPool, + ): + return cls(offset_provider_type=offset_provider_type, uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): node = self.generic_visit(node, **kwargs) @@ -65,6 +71,12 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) ] if any(fuse_args): - return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + return fuse_as_fieldop.fuse_as_fieldop( + node, + fuse_args, + uids=self.uids, + offset_provider_type=self.offset_provider_type, + enable_cse=True, + ) return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 9da51c167d..1fb30b096d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -79,7 +79,7 @@ def apply_common_transforms( ir, uids=uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code ir = inline_dynamic_shifts.InlineDynamicShifts.apply( - ir, uids=uids + ir, offset_provider_type=offset_provider_type, uids=uids ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) @@ -183,7 +183,7 @@ def apply_fieldview_transforms( ir, offset_provider_type=offset_provider_type, uids=uids ) ir = inline_dynamic_shifts.InlineDynamicShifts.apply( - ir, uids=uids + ir, offset_provider_type=offset_provider_type, uids=uids ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py index 7f15ccd404..0817b5f19d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -25,7 +25,9 @@ def test_inline_dynamic_shift_as_fieldop_arg(uids): ) )("inp", "offset_field") - actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee, uids=uids) + actual = inline_dynamic_shifts.InlineDynamicShifts.apply( + testee, offset_provider_type={}, uids=uids + ) assert actual == expected @@ -42,5 +44,7 @@ def test_inline_dynamic_shift_let_var(uids): ) )("inp", "offset_field") - actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee, uids=uids) + actual = inline_dynamic_shifts.InlineDynamicShifts.apply( + testee, offset_provider_type={}, uids=uids + ) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 8258180ad2..2e9f97d693 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -11,7 +11,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils from gt4py.next.iterator.transforms import ( - fuse_as_fieldop as fasfop, + fuse_as_fieldop, collapse_tuple as ct, ) from gt4py.next.type_system import type_specifications as ts @@ -40,8 +40,8 @@ def test_trivial(uids: utils.IDGeneratorPool): ), d, )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, uids=uids ) assert actual == expected @@ -50,8 +50,8 @@ def test_trivial_literal(uids: utils.IDGeneratorPool): d = im.domain("cartesian_domain", {}) testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, uids=uids ) assert actual == expected @@ -69,8 +69,8 @@ def test_trivial_same_arg_twice(uids: utils.IDGeneratorPool): ), d, )(im.ref("inp1", field_type), im.ref("inp2", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -94,8 +94,8 @@ def test_tuple_arg(uids: utils.IDGeneratorPool): ), d, )() - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -114,8 +114,8 @@ def test_symref_used_twice(uids: utils.IDGeneratorPool): ), d, )("inp1", "inp2") - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -129,8 +129,12 @@ def test_no_inline(uids: utils.IDGeneratorPool): ), d1, )(im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type))) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, + offset_provider_type={"IOff": IDim}, + allow_undeclared_symbols=True, + enable_cse=False, + uids=uids, ) assert actual == testee @@ -153,8 +157,8 @@ def test_staged_inlining(uids: utils.IDGeneratorPool): ), d, )(im.ref("a", field_type), im.ref("b", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -169,8 +173,8 @@ def test_make_tuple_fusion_trivial(uids: utils.IDGeneratorPool): im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), d, )(im.ref("a", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) # simplify to remove unnecessary make_tuple call `{v[0], v[1]}(actual)` actual_simplified = ct.CollapseTuple.apply( @@ -189,8 +193,8 @@ def test_make_tuple_fusion_symref(uids: utils.IDGeneratorPool): im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), d, )(im.ref("a", field_type), im.ref("b", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) # simplify to remove unnecessary make_tuple call actual_simplified = ct.CollapseTuple.apply( @@ -209,8 +213,8 @@ def test_make_tuple_fusion_symref_same_ref(uids: utils.IDGeneratorPool): im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), d, )(im.ref("a", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) # simplify to remove unnecessary make_tuple call actual_simplified = ct.CollapseTuple.apply( @@ -234,8 +238,8 @@ def test_make_tuple_nested(uids: utils.IDGeneratorPool): ), d, )(im.ref("a", field_type), im.ref("b", field_type), im.ref("c", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) # simplify to remove unnecessary make_tuple call actual_simplified = ct.CollapseTuple.apply( @@ -276,8 +280,8 @@ def test_make_tuple_fusion_different_domains(uids: utils.IDGeneratorPool): im.tuple_get(1, "__fasfop_1"), ) ) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -312,8 +316,12 @@ def test_partial_inline(uids: utils.IDGeneratorPool): "inp1", "inp2", ) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, + offset_provider_type={"IOff": IDim}, + allow_undeclared_symbols=True, + enable_cse=False, + uids=uids, ) assert actual == expected @@ -336,8 +344,8 @@ def test_chained_fusion(uids: utils.IDGeneratorPool): ), d, )(im.ref("inp1", field_type), im.ref("inp2", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -353,8 +361,8 @@ def test_inline_as_fieldop_with_list_dtype(uids: utils.IDGeneratorPool): expected = im.as_fieldop( im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d )(im.ref("inp", list_field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -364,8 +372,8 @@ def test_inline_into_scan(uids: utils.IDGeneratorPool): scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0) testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type))) expected = im.as_fieldop(scan, d)(im.ref("a", field_type)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected @@ -377,8 +385,8 @@ def test_no_inline_into_scan(uids: utils.IDGeneratorPool): ) scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type)) testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == testee @@ -390,7 +398,7 @@ def test_opage_arg_deduplication(uids: utils.IDGeneratorPool): im.lambda_("__arg1")(im.plus(im.deref("__arg1"), im.deref("__arg1"))), d, )(im.index(IDim)) - actual = fasfop.FuseAsFieldOp.apply( - testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids ) assert actual == expected