Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,16 @@ def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce nodes if they are left in the IR at this point, this may lead to
# do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# backend (single pass eager depth first visit approach)
# (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795
# do also not collect lifts or applied lifts as they become invisible to the lift inliner
# otherwise
# do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if cpm.is_call_to(
node, ("lift", "shift", "reduce", "map_", "index")
node, ("lift", "shift", "neighbors", "reduce", "map_", "index")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you extend the comment above to document why `neighbors` should not be collected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if this can be undone after unroll_reduce (#2267)?

) or cpm.is_applied_lift(node):
return False
return True
Expand Down
33 changes: 28 additions & 5 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
misc as ir_misc,
)
from gt4py.next.iterator.transforms import (
cse,
fixed_point_transformation,
inline_center_deref_lift_vars,
inline_lambdas,
Expand Down Expand Up @@ -119,7 +120,12 @@ def _prettify_as_fieldop_args(


def fuse_as_fieldop(
expr: itir.Expr, eligible_args: list[bool], *, uids: utils.IDGeneratorPool
expr: itir.Expr,
eligible_args: list[bool],
*,
offset_provider_type: common.OffsetProviderType,
enable_cse: bool,
uids: utils.IDGeneratorPool,
) -> itir.Expr:
assert cpm.is_applied_as_fieldop(expr)

Expand Down Expand Up @@ -183,6 +189,11 @@ def fuse_as_fieldop(
new_stencil = inline_lifts.InlineLifts().visit(new_stencil)

new_node = im.as_fieldop(new_stencil, domain)(*new_args.values())
if enable_cse:
# TODO(havogt): We should investigate how to keep the tree small without having to run CSE.
new_node = cse.CommonSubexpressionElimination.apply(
new_node, within_stencil=False, uids=uids, offset_provider_type=offset_provider_type
)

return new_node

Expand Down Expand Up @@ -281,6 +292,8 @@ def all(self) -> FuseAsFieldOp.Transformation:
enabled_transformations = Transformation.all()

uids: utils.IDGeneratorPool
offset_provider_type: common.OffsetProviderType
enable_cse: bool # option to disable is mainly for testing purposes

@classmethod
def apply(
Expand All @@ -292,6 +305,7 @@ def apply(
allow_undeclared_symbols=False,
within_set_at_expr: Optional[bool] = None,
enabled_transformations: Optional[Transformation] = None,
enable_cse: bool = True,
):
enabled_transformations = enabled_transformations or cls.enabled_transformations

Expand All @@ -304,9 +318,12 @@ def apply(
if within_set_at_expr is None:
within_set_at_expr = not isinstance(node, itir.Program)

new_node = cls(uids=uids, enabled_transformations=enabled_transformations).visit(
node, within_set_at_expr=within_set_at_expr
)
new_node = cls(
uids=uids,
enabled_transformations=enabled_transformations,
offset_provider_type=offset_provider_type,
enable_cse=enable_cse,
).visit(node, within_set_at_expr=within_set_at_expr)
# The `FuseAsFieldOp` pass does not fully preserve the type information yet. In particular
# for the generated lifts this is tricky and error-prone. For simplicity, we just reinfer
# everything here ensuring later passes can use the information.
Expand Down Expand Up @@ -391,7 +408,13 @@ def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs):
]
if any(eligible_els):
return self.visit(
fuse_as_fieldop(node, eligible_els, uids=self.uids),
fuse_as_fieldop(
node,
eligible_els,
uids=self.uids,
offset_provider_type=self.offset_provider_type,
enable_cse=self.enable_cse,
),
**{**kwargs, "recurse": False},
)
return None
Expand Down
20 changes: 16 additions & 4 deletions src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm
from gt4py import eve
from gt4py.next import utils
from gt4py.next import common, utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts
from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs
Expand All @@ -33,11 +33,17 @@ def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]:

@dataclasses.dataclass
class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait):
offset_provider_type: common.OffsetProviderType
uids: utils.IDGeneratorPool

@classmethod
def apply(cls, node: itir.Program, uids: utils.IDGeneratorPool):
return cls(uids=uids).visit(node)
def apply(
cls,
node: itir.Program,
offset_provider_type: common.OffsetProviderType,
uids: utils.IDGeneratorPool,
):
return cls(offset_provider_type=offset_provider_type, uids=uids).visit(node)

def visit_FunCall(self, node: itir.FunCall, **kwargs):
node = self.generic_visit(node, **kwargs)
Expand Down Expand Up @@ -65,6 +71,12 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True)
]
if any(fuse_args):
return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids)
return fuse_as_fieldop.fuse_as_fieldop(
node,
fuse_args,
uids=self.uids,
offset_provider_type=self.offset_provider_type,
enable_cse=True,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this option mainly for testing, but now we could disable it here and we wouldn't have to pass the offset_provider_type. Not sure what to do...

)

return node
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def apply_common_transforms(
ir, uids=uids, offset_provider_type=offset_provider_type
) # domain inference does not support dead-code
ir = inline_dynamic_shifts.InlineDynamicShifts.apply(
ir, uids=uids
ir, offset_provider_type=offset_provider_type, uids=uids
) # domain inference does not support dynamic offsets yet
ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
Expand Down Expand Up @@ -183,7 +183,7 @@ def apply_fieldview_transforms(
ir, offset_provider_type=offset_provider_type, uids=uids
)
ir = inline_dynamic_shifts.InlineDynamicShifts.apply(
ir, uids=uids
ir, offset_provider_type=offset_provider_type, uids=uids
) # domain inference does not support dynamic offsets yet

ir = infer_domain_ops.InferDomainOps.apply(ir)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def test_inline_dynamic_shift_as_fieldop_arg(uids):
)
)("inp", "offset_field")

actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee, uids=uids)
actual = inline_dynamic_shifts.InlineDynamicShifts.apply(
testee, offset_provider_type={}, uids=uids
)
assert actual == expected


Expand All @@ -42,5 +44,7 @@ def test_inline_dynamic_shift_let_var(uids):
)
)("inp", "offset_field")

actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee, uids=uids)
actual = inline_dynamic_shifts.InlineDynamicShifts.apply(
testee, offset_provider_type={}, uids=uids
)
assert actual == expected
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils
from gt4py.next.iterator.transforms import (
fuse_as_fieldop as fasfop,
fuse_as_fieldop,
collapse_tuple as ct,
)
from gt4py.next.type_system import type_specifications as ts
Expand Down Expand Up @@ -40,8 +40,8 @@ def test_trivial(uids: utils.IDGeneratorPool):
),
d,
)(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, uids=uids
)
assert actual == expected

Expand All @@ -50,8 +50,8 @@ def test_trivial_literal(uids: utils.IDGeneratorPool):
d = im.domain("cartesian_domain", {})
testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3)
expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)()
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, uids=uids
)
assert actual == expected

Expand All @@ -69,8 +69,8 @@ def test_trivial_same_arg_twice(uids: utils.IDGeneratorPool):
),
d,
)(im.ref("inp1", field_type), im.ref("inp2", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand All @@ -94,8 +94,8 @@ def test_tuple_arg(uids: utils.IDGeneratorPool):
),
d,
)()
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand All @@ -114,8 +114,8 @@ def test_symref_used_twice(uids: utils.IDGeneratorPool):
),
d,
)("inp1", "inp2")
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand All @@ -129,8 +129,12 @@ def test_no_inline(uids: utils.IDGeneratorPool):
),
d1,
)(im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee,
offset_provider_type={"IOff": IDim},
allow_undeclared_symbols=True,
enable_cse=False,
uids=uids,
)
assert actual == testee

Expand All @@ -153,8 +157,8 @@ def test_staged_inlining(uids: utils.IDGeneratorPool):
),
d,
)(im.ref("a", field_type), im.ref("b", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand All @@ -169,8 +173,8 @@ def test_make_tuple_fusion_trivial(uids: utils.IDGeneratorPool):
im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))),
d,
)(im.ref("a", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
# simplify to remove unnecessary make_tuple call `{v[0], v[1]}(actual)`
actual_simplified = ct.CollapseTuple.apply(
Expand All @@ -189,8 +193,8 @@ def test_make_tuple_fusion_symref(uids: utils.IDGeneratorPool):
im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))),
d,
)(im.ref("a", field_type), im.ref("b", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
# simplify to remove unnecessary make_tuple call
actual_simplified = ct.CollapseTuple.apply(
Expand All @@ -209,8 +213,8 @@ def test_make_tuple_fusion_symref_same_ref(uids: utils.IDGeneratorPool):
im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))),
d,
)(im.ref("a", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
# simplify to remove unnecessary make_tuple call
actual_simplified = ct.CollapseTuple.apply(
Expand All @@ -234,8 +238,8 @@ def test_make_tuple_nested(uids: utils.IDGeneratorPool):
),
d,
)(im.ref("a", field_type), im.ref("b", field_type), im.ref("c", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
# simplify to remove unnecessary make_tuple call
actual_simplified = ct.CollapseTuple.apply(
Expand Down Expand Up @@ -276,8 +280,8 @@ def test_make_tuple_fusion_different_domains(uids: utils.IDGeneratorPool):
im.tuple_get(1, "__fasfop_1"),
)
)
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand Down Expand Up @@ -312,8 +316,12 @@ def test_partial_inline(uids: utils.IDGeneratorPool):
"inp1",
"inp2",
)
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee,
offset_provider_type={"IOff": IDim},
allow_undeclared_symbols=True,
enable_cse=False,
uids=uids,
)
assert actual == expected

Expand All @@ -336,8 +344,8 @@ def test_chained_fusion(uids: utils.IDGeneratorPool):
),
d,
)(im.ref("inp1", field_type), im.ref("inp2", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand All @@ -353,8 +361,8 @@ def test_inline_as_fieldop_with_list_dtype(uids: utils.IDGeneratorPool):
expected = im.as_fieldop(
im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d
)(im.ref("inp", list_field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand All @@ -364,8 +372,8 @@ def test_inline_into_scan(uids: utils.IDGeneratorPool):
scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0)
testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type)))
expected = im.as_fieldop(scan, d)(im.ref("a", field_type))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected

Expand All @@ -377,8 +385,8 @@ def test_no_inline_into_scan(uids: utils.IDGeneratorPool):
)
scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type))
testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan)
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == testee

Expand All @@ -390,7 +398,7 @@ def test_opage_arg_deduplication(uids: utils.IDGeneratorPool):
im.lambda_("__arg1")(im.plus(im.deref("__arg1"), im.deref("__arg1"))),
d,
)(im.index(IDim))
actual = fasfop.FuseAsFieldOp.apply(
testee, uids=uids, offset_provider_type={}, allow_undeclared_symbols=True
actual = fuse_as_fieldop.FuseAsFieldOp.apply(
testee, offset_provider_type={}, allow_undeclared_symbols=True, enable_cse=False, uids=uids
)
assert actual == expected
Loading