From 3c353ff30ec2e487b564a1d0cc2c5523277fffaa Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Sep 2025 00:14:55 +0200 Subject: [PATCH 1/8] fix[next]: apply cse in fuse_as_fieldop --- src/gt4py/next/iterator/transforms/cse.py | 2 +- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 2fcbd5df0d..8c98c7acc3 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -396,7 +396,7 @@ def extract_subexpression( if not eligible_ids: continue - expr_id = uid_generator.sequential_id() + expr_id = uid_generator.random_id() # TODO: undo, but make sure we don't get collisions extracted[itir.Sym(id=expr_id)] = expr expr_ref = itir.SymRef(id=expr_id) for id_ in eligible_ids: diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 4b3a258396..c119d5c6fd 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -25,6 +25,7 @@ misc as ir_misc, ) from gt4py.next.iterator.transforms import ( + cse, fixed_point_transformation, inline_center_deref_lift_vars, inline_lambdas, @@ -182,6 +183,7 @@ def fuse_as_fieldop( new_stencil, opcount_preserving=True, force_inline_lift_args=True ) new_stencil = inline_lifts.InlineLifts().visit(new_stencil) + new_stencil = cse.CommonSubexpressionElimination.apply(new_stencil, within_stencil=True) new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) From 5cf6e49ef9122bcb7a08f565e06e9d2f3771ba03 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Sep 2025 07:41:09 +0200 Subject: [PATCH 2/8] fix name collision in cse --- src/gt4py/next/iterator/transforms/cse.py | 13 ++++++++----- .../next/iterator/transforms/fuse_as_fieldop.py | 4 +++- src/gt4py/next/iterator/transforms/pass_manager.py | 5 ++++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 8c98c7acc3..e5ecd7d49a 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -396,7 +396,7 @@ def extract_subexpression( if not eligible_ids: continue - expr_id = uid_generator.random_id() # TODO: undo, but make sure we don't get collisions + expr_id = uid_generator.sequential_id(prefix="_cs") extracted[itir.Sym(id=expr_id)] = expr expr_ref = itir.SymRef(id=expr_id) for id_ in eligible_ids: @@ -435,9 +435,7 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) - uids: UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: UIDGenerator(prefix="_cs") - ) + uids: UIDGenerator = dataclasses.field(repr=False) collect_all: bool = dataclasses.field(default=False) @@ -447,6 +445,8 @@ def apply( node: ProgramOrExpr, within_stencil: bool | None = None, offset_provider_type: common.OffsetProviderType | None = None, + *, + uids: UIDGenerator | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, itir.Program) if is_program: @@ -457,11 +457,14 @@ def apply( "The expression's context must be specified using `within_stencil`." ) + if not uids: + uids = UIDGenerator() + offset_provider_type = offset_provider_type or {} node = itir_type_inference.infer( node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program ) - return cls().visit(node, within_stencil=within_stencil) + return cls(uids=uids).visit(node, within_stencil=within_stencil) def generic_visit(self, node, **kwargs): if cpm.is_call_to(node, "as_fieldop"): diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index c119d5c6fd..1e188b620e 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -183,7 +183,9 @@ def fuse_as_fieldop( new_stencil, opcount_preserving=True, force_inline_lift_args=True ) new_stencil = inline_lifts.InlineLifts().visit(new_stencil) - new_stencil = cse.CommonSubexpressionElimination.apply(new_stencil, within_stencil=True) + new_stencil = cse.CommonSubexpressionElimination.apply( + new_stencil, within_stencil=True, uids=uids + ) new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 9c5e154011..d4c7d2c2e5 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -65,6 +65,7 @@ def apply_common_transforms( tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") mergeasfop_uids = eve_utils.UIDGenerator() collapse_tuple_uids = eve_utils.UIDGenerator() + cse_uids = eve_utils.UIDGenerator() ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) @@ -126,7 +127,9 @@ def apply_common_transforms( # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) + ir = CommonSubexpressionElimination.apply( + ir, offset_provider_type=offset_provider_type, uids=cse_uids + ) ir = MergeLet().visit(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True) From 6ffd03cc7bf57ccde10214224722fb34beb3d986 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Sep 2025 08:27:27 +0200 Subject: [PATCH 3/8] pass offset_provider_type --- .../iterator/transforms/fixed_point_transformation.py | 5 ++++- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 9 ++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py index 3818f3864a..7427450c4d 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -42,7 +42,10 @@ def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: def _post_transform(self, node: ir.Node, new_node: ir.Node) -> ir.Node: if self.REINFER_TYPES: - itir_type_inference.reinfer(new_node) + kwargs = {} + if hasattr(self, "offset_provider_type"): + kwargs["offset_provider_type"] = self.offset_provider_type + itir_type_inference.reinfer(new_node, **kwargs) self._preserve_annex(node, new_node) return new_node diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 1e188b620e..d2f42bcf65 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -282,6 +282,7 @@ def all(self) -> FuseAsFieldOp.Transformation: enabled_transformations = Transformation.all() uids: eve_utils.UIDGenerator + offset_provider_type: common.OffsetProviderType @classmethod def apply( @@ -308,9 +309,11 @@ def apply( if not uids: uids = eve_utils.UIDGenerator() - new_node = cls(uids=uids, enabled_transformations=enabled_transformations).visit( - node, within_set_at_expr=within_set_at_expr - ) + new_node = cls( + uids=uids, + enabled_transformations=enabled_transformations, + offset_provider_type=offset_provider_type, + ).visit(node, within_set_at_expr=within_set_at_expr) # The `FuseAsFieldOp` pass does not fully preserve the type information yet. In particular # for the generated lifts this is tricky and error-prone. For simplicity, we just reinfer # everything here ensuring later passes can use the information. From 5c90bcf21b379bf354e968217b4ad4edc2aab5be Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 11 Sep 2025 09:14:41 +0200 Subject: [PATCH 4/8] dont cse collect neighbors --- src/gt4py/next/iterator/transforms/cse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index e5ecd7d49a..7bc986f00b 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -90,7 +90,7 @@ def _is_collectable_expr(node: itir.Node) -> bool: # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement # instead of an as_fieldop if cpm.is_call_to( - node, ("lift", "shift", "reduce", "map_", "index") + node, ("lift", "shift", "neighbors", "reduce", "map_", "index") ) or cpm.is_applied_lift(node): return False return True From 13a85970e9bebc248a537b160ce7fcb4f6ded7c8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 27 Nov 2025 13:52:11 +0100 Subject: [PATCH 5/8] fix cse type reinfer --- .../transforms/fixed_point_transformation.py | 5 +---- .../iterator/transforms/fuse_as_fieldop.py | 20 ++++++++++++++----- .../transforms/inline_dynamic_shifts.py | 18 ++++++++++++++--- .../next/iterator/transforms/pass_manager.py | 4 ++-- .../test_inline_dynamic_shifts.py | 4 ++-- 5 files changed, 35 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py index 7427450c4d..3818f3864a 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -42,10 +42,7 @@ def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: def _post_transform(self, node: ir.Node, new_node: ir.Node) -> ir.Node: if self.REINFER_TYPES: - kwargs = {} - if hasattr(self, "offset_provider_type"): - kwargs["offset_provider_type"] = self.offset_provider_type - itir_type_inference.reinfer(new_node, **kwargs) + itir_type_inference.reinfer(new_node) self._preserve_annex(node, new_node) return new_node diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index d2f42bcf65..2a93f992d5 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -121,7 +121,11 @@ def _prettify_as_fieldop_args( def fuse_as_fieldop( - expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator + expr: itir.Expr, + eligible_args: list[bool], + *, + offset_provider_type: common.OffsetProviderType, + uids: eve_utils.UIDGenerator, ) -> itir.Expr: assert cpm.is_applied_as_fieldop(expr) @@ -183,11 +187,12 @@ def fuse_as_fieldop( new_stencil, opcount_preserving=True, force_inline_lift_args=True ) new_stencil = inline_lifts.InlineLifts().visit(new_stencil) - new_stencil = cse.CommonSubexpressionElimination.apply( - new_stencil, within_stencil=True, uids=uids - ) new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) + # TODO(havogt): We should investigate how to keep the tree small without having to run CSE. + new_node = cse.CommonSubexpressionElimination.apply( + new_node, within_stencil=False, uids=uids, offset_provider_type=offset_provider_type + ) return new_node @@ -398,7 +403,12 @@ def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): ] if any(eligible_els): return self.visit( - fuse_as_fieldop(node, eligible_els, uids=self.uids), + fuse_as_fieldop( + node, + eligible_els, + uids=self.uids, + offset_provider_type=self.offset_provider_type, + ), **{**kwargs, "recurse": False}, ) return None diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py index f24326efe6..4158ffa2d9 100644 --- a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -12,6 +12,7 @@ import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs @@ -34,14 +35,20 @@ def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: @dataclasses.dataclass class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + offset_provider_type: common.OffsetProviderType uids: eve_utils.UIDGenerator @classmethod - def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + def apply( + cls, + node: itir.Program, + offset_provider_type: common.OffsetProviderType, + uids: Optional[eve_utils.UIDGenerator] = None, + ): if not uids: uids = eve_utils.UIDGenerator() - return cls(uids=uids).visit(node) + return cls(offset_provider_type=offset_provider_type, uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): node = self.generic_visit(node, **kwargs) @@ -69,6 +76,11 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) ] if any(fuse_args): - return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + return fuse_as_fieldop.fuse_as_fieldop( + node, + fuse_args, + uids=self.uids, + offset_provider_type=self.offset_provider_type, + ) return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index f3b56f7f98..01a4f6ed75 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -83,7 +83,7 @@ def apply_common_transforms( ir, collapse_tuple_uids=collapse_tuple_uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code ir = inline_dynamic_shifts.InlineDynamicShifts.apply( - ir + ir, offset_provider_type=offset_provider_type ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) @@ -183,7 +183,7 @@ def apply_fieldview_transforms( ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination(ir, offset_provider_type=offset_provider_type) ir = inline_dynamic_shifts.InlineDynamicShifts.apply( - ir + ir, offset_provider_type=offset_provider_type ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py index ff7a761c5a..205b8447aa 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -27,7 +27,7 @@ def test_inline_dynamic_shift_as_fieldop_arg(): ) )("inp", "offset_field") - actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee, offset_provider_type={}) assert actual == expected @@ -44,5 +44,5 @@ def test_inline_dynamic_shift_let_var(): ) )("inp", "offset_field") - actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee, offset_provider_type={}) assert actual == expected From 0f6737cafc198e699b18f41530f5deb150fc41eb Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 27 Nov 2025 15:45:38 +0100 Subject: [PATCH 6/8] fix test references --- .../iterator/transforms/fuse_as_fieldop.py | 14 +++++--- .../transforms/inline_dynamic_shifts.py | 1 + .../transforms_tests/test_cse.py | 6 ++-- .../transforms_tests/test_fuse_as_fieldop.py | 32 +++++++++---------- 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 2a93f992d5..74da4c3e2d 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -125,6 +125,7 @@ def fuse_as_fieldop( eligible_args: list[bool], *, offset_provider_type: common.OffsetProviderType, + enable_cse: bool, uids: eve_utils.UIDGenerator, ) -> itir.Expr: assert cpm.is_applied_as_fieldop(expr) @@ -189,10 +190,11 @@ def fuse_as_fieldop( new_stencil = inline_lifts.InlineLifts().visit(new_stencil) new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) - # TODO(havogt): We should investigate how to keep the tree small without having to run CSE. - new_node = cse.CommonSubexpressionElimination.apply( - new_node, within_stencil=False, uids=uids, offset_provider_type=offset_provider_type - ) + if enable_cse: + # TODO(havogt): We should investigate how to keep the tree small without having to run CSE. + new_node = cse.CommonSubexpressionElimination.apply( + new_node, within_stencil=False, uids=uids, offset_provider_type=offset_provider_type + ) return new_node @@ -288,6 +290,7 @@ def all(self) -> FuseAsFieldOp.Transformation: uids: eve_utils.UIDGenerator offset_provider_type: common.OffsetProviderType + enable_cse: bool # option to disable is mainly for testing purposes @classmethod def apply( @@ -299,6 +302,7 @@ def apply( allow_undeclared_symbols=False, within_set_at_expr: Optional[bool] = None, enabled_transformations: Optional[Transformation] = None, + enable_cse: bool = True, ): enabled_transformations = enabled_transformations or cls.enabled_transformations @@ -318,6 +322,7 @@ def apply( uids=uids, enabled_transformations=enabled_transformations, offset_provider_type=offset_provider_type, + enable_cse=enable_cse, ).visit(node, within_set_at_expr=within_set_at_expr) # The `FuseAsFieldOp` pass does not fully preserve the type information yet. In particular # for the generated lifts this is tricky and error-prone. For simplicity, we just reinfer @@ -408,6 +413,7 @@ def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): eligible_els, uids=self.uids, offset_provider_type=self.offset_provider_type, + enable_cse=self.enable_cse, ), **{**kwargs, "recurse": False}, ) diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py index 4158ffa2d9..924537890a 100644 --- a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -81,6 +81,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): fuse_args, uids=self.uids, offset_provider_type=self.offset_provider_type, + enable_cse=True, ) return node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index f618ba409d..8bbbb80f0d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -228,12 +228,12 @@ def is_let(node: ir.Expr): b = 2 c = a + b d = 3 - _let_result_1 = c + d - return _let_result_1 + 4 + _cs_1 = c + d + return _cs_1 + 4 """ ).strip() - uid_gen = UIDGenerator(prefix="_let_result") + uid_gen = UIDGenerator() def convert_to_assignment_stmt_form(node: ir.Expr) -> tuple[list[tuple[str, ir.Expr]], ir.Expr]: assignment_stmts: list[tuple[str, ir.Expr]] = [] diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 58febf03c9..337c74f202 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -68,7 +68,7 @@ def test_trivial_same_arg_twice(): d, )(im.ref("inp1", field_type), im.ref("inp2", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -93,7 +93,7 @@ def test_tuple_arg(): d, )() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -113,7 +113,7 @@ def test_symref_used_twice(): d, )("inp1", "inp2") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -128,7 +128,7 @@ def test_no_inline(): d1, )(im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == testee @@ -152,7 +152,7 @@ def test_staged_inlining(): d, )(im.ref("a", field_type), im.ref("b", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -168,7 +168,7 @@ def test_make_tuple_fusion_trivial(): d, )(im.ref("a", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) # simplify to remove unnecessary make_tuple call `{v[0], v[1]}(actual)` actual_simplified = collapse_tuple.CollapseTuple.apply( @@ -188,7 +188,7 @@ def test_make_tuple_fusion_symref(): d, )(im.ref("a", field_type), im.ref("b", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) # simplify to remove unnecessary make_tuple call actual_simplified = collapse_tuple.CollapseTuple.apply( @@ -208,7 +208,7 @@ def test_make_tuple_fusion_symref_same_ref(): d, )(im.ref("a", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) # simplify to remove unnecessary make_tuple call actual_simplified = collapse_tuple.CollapseTuple.apply( @@ -233,7 +233,7 @@ def test_make_tuple_nested(): d, )(im.ref("a", field_type), im.ref("b", field_type), im.ref("c", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) # simplify to remove unnecessary make_tuple call actual_simplified = collapse_tuple.CollapseTuple.apply( @@ -275,7 +275,7 @@ def test_make_tuple_fusion_different_domains(): ) ) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -311,7 +311,7 @@ def test_partial_inline(): "inp2", ) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -335,7 +335,7 @@ def test_chained_fusion(): d, )(im.ref("inp1", field_type), im.ref("inp2", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -352,7 +352,7 @@ def test_inline_as_fieldop_with_list_dtype(): im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d )(im.ref("inp", list_field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -363,7 +363,7 @@ def test_inline_into_scan(): testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type))) expected = im.as_fieldop(scan, d)(im.ref("a", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected @@ -376,7 +376,7 @@ def test_no_inline_into_scan(): scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type)) testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == testee @@ -389,6 +389,6 @@ def test_opage_arg_deduplication(): d, )(im.index(IDim)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider_type={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False ) assert actual == expected From d82fd067cf26d635d043c5e8e00e93943ab6c771 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 28 Nov 2025 13:11:40 +0100 Subject: [PATCH 7/8] make the unique ids uniquer --- src/gt4py/next/iterator/transforms/cse.py | 27 +++++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 7bc986f00b..6826f3977b 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -262,6 +262,19 @@ def visit_SymRef( state.used_symbol_ids.add(id(symtable[node.id])) +def _unique_id(uid_generator: UIDGenerator) -> str: + """ + Generate a unique id with '_cs' prefix using the given uid generator. + + In case cse is used from a different context where the uid generator already has a prefix. + """ + prefix = uid_generator.prefix or "" + if prefix != "_cs": + prefix = f"{prefix}_cs" + + return uid_generator.sequential_id(prefix=prefix) + + def extract_subexpression( node: itir.Expr, predicate: Callable[[itir.Expr, int], bool], @@ -299,10 +312,10 @@ def extract_subexpression( ... expr, predicate, UIDGenerator(prefix="_subexpr") ... ) >>> print(new_expr) - _subexpr_1 + (_subexpr_1 + z) + _subexpr_cs_1 + (_subexpr_cs_1 + z) >>> for sym, subexpr in extracted_subexprs.items(): ... print(f"`{sym}`: `{subexpr}`") - `_subexpr_1`: `x + y` + `_subexpr_cs_1`: `x + y` The order of the extraction can be configured using `deepest_expr_first`. By default, the nodes closer to the root are eliminated first: @@ -315,10 +328,10 @@ def extract_subexpression( ... expr, predicate, UIDGenerator(prefix="_subexpr"), deepest_expr_first=False ... ) >>> print(new_expr) - _subexpr_1 + _subexpr_1 + _subexpr_cs_1 + _subexpr_cs_1 >>> for sym, subexpr in extracted_subexprs.items(): ... print(f"`{sym}`: `{subexpr}`") - `_subexpr_1`: `x + y + (x + y)` + `_subexpr_cs_1`: `x + y + (x + y)` Since `(x+y)` is a child of one of the expressions it is ignored: @@ -339,10 +352,10 @@ def extract_subexpression( ... deepest_expr_first=True, ... ) >>> print(new_expr) - _subexpr_1 + _subexpr_1 + (_subexpr_1 + _subexpr_1) + _subexpr_cs_1 + _subexpr_cs_1 + (_subexpr_cs_1 + _subexpr_cs_1) >>> for sym, subexpr in extracted_subexprs.items(): ... print(f"`{sym}`: `{subexpr}`") - `_subexpr_1`: `x + y` + `_subexpr_cs_1`: `x + y` Note that this requires `once_only` to be set right now. """ @@ -396,7 +409,7 @@ def extract_subexpression( if not eligible_ids: continue - expr_id = uid_generator.sequential_id(prefix="_cs") + expr_id = _unique_id(uid_generator) extracted[itir.Sym(id=expr_id)] = expr expr_ref = itir.SymRef(id=expr_id) for id_ in eligible_ids: From 40791356b2c0a598c6df2775c2101fd4ceada0ca Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 3 Dec 2025 10:14:35 +0100 Subject: [PATCH 8/8] add comment --- src/gt4py/next/iterator/transforms/cse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 072b3c6764..0a8fd7b59f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -81,10 +81,10 @@ def _is_collectable_expr(node: itir.Node) -> bool: if isinstance(node, itir.FunCall): # do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be # visited, to ensure symbol dependencies are recognized correctly. - # do also not collect reduce nodes if they are left in the IR at this point, this may lead to + # do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend - # backend (single pass eager depth first visit approach) + # backend (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795 # do also not collect lifts or applied lifts as they become invisible to the lift inliner # otherwise # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement