diff --git a/tests/interpreters/test_ematch_interpreter.py b/tests/interpreters/test_ematch_interpreter.py
index 419acb2c67..72931e9792 100644
--- a/tests/interpreters/test_ematch_interpreter.py
+++ b/tests/interpreters/test_ematch_interpreter.py
@@ -1,11 +1,18 @@
+from dataclasses import dataclass
+from typing import Any, cast
+
+import pytest
+
 from xdsl.builder import ImplicitBuilder
-from xdsl.dialects import ematch, equivalence, pdl, test
-from xdsl.dialects.builtin import ModuleOp, i32
+from xdsl.dialects import arith, ematch, equivalence, pdl, test
+from xdsl.dialects.builtin import IntegerAttr, ModuleOp, i32
 from xdsl.interpreter import Interpreter
 from xdsl.interpreters.ematch import EmatchFunctions
+from xdsl.interpreters.eqsat_pdl_interp import NonPropagatingDataFlowSolver
 from xdsl.interpreters.pdl_interp import PDLInterpFunctions
-from xdsl.ir import Block, Region
+from xdsl.ir import Block, Operation, Region
 from xdsl.pattern_rewriter import PatternRewriter
+from xdsl.utils.exceptions import InterpretationError
 from xdsl.utils.test_value import create_ssa_value
 
 
@@ -233,3 +240,399 @@ def test_get_or_create_class_creates_new_class_for_block_arg():
     assert isinstance(result, equivalence.ClassOp)
     assert block_arg in result.operands
     assert ematch_funcs.eclass_union_find.find(result) is result
+
+
+def test_union_val():
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+
+    interpreter.run_op(
+        ematch.UnionOp(
+            create_ssa_value(pdl.ValueType()), create_ssa_value(pdl.ValueType())
+        ),
+        (v0, v1),
+    )
+
+    # After union, both values should be operands of the same ClassOp
+    eclass_a = ematch_funcs.get_or_create_class(interpreter, v0)
+    eclass_b = ematch_funcs.get_or_create_class(interpreter, v1)
+    assert eclass_a is eclass_b
+    assert set(eclass_a.operands) == {v0, v1}
+
+
+def test_eclass_union_same_class():
+    """Union of the same eclass with itself is a no-op."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        eclass = equivalence.ClassOp(v0, res_type=i32)
+
+    ematch_funcs.eclass_union_find.add(eclass)
+
+    result = ematch_funcs.eclass_union(interpreter, eclass, eclass)
+    assert result is False
+    assert len(ematch_funcs.worklist) == 0
+
+
+def test_eclass_union_two_regular_classes():
+    """Union of two regular eclasses merges operands and replaces one."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+        eclass_a = equivalence.ClassOp(v0, res_type=i32)
+        eclass_b = equivalence.ClassOp(v1, res_type=i32)
+
+    ematch_funcs.eclass_union_find.add(eclass_a)
+    ematch_funcs.eclass_union_find.add(eclass_b)
+
+    result = ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b)
+    assert result is True
+
+    canonical = ematch_funcs.eclass_union_find.find(eclass_a)
+    assert ematch_funcs.eclass_union_find.find(eclass_b) is canonical
+    assert set(canonical.operands) == {v0, v1}
+
+
+def test_eclass_union_constant_with_regular():
+    """Union of ConstantClassOp with regular ClassOp keeps the constant as canonical."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        const_op = arith.ConstantOp(IntegerAttr(1, i32))
+        regular_op = test.TestOp(result_types=(i32,))
+
+        const_eclass = equivalence.ConstantClassOp(const_op.result)
+        regular_eclass = equivalence.ClassOp(regular_op.results[0], res_type=i32)
+
+    ematch_funcs.eclass_union_find.add(const_eclass)
+    ematch_funcs.eclass_union_find.add(regular_eclass)
+
+    result = ematch_funcs.eclass_union(interpreter, const_eclass, regular_eclass)
+    assert result is True
+
+    canonical = ematch_funcs.eclass_union_find.find(const_eclass)
+    assert isinstance(canonical, equivalence.ConstantClassOp)
+    assert canonical.value == IntegerAttr(1, i32)
+    assert set(canonical.operands) == {const_op.result, regular_op.results[0]}
+
+
+def test_eclass_union_regular_with_constant():
+    """Union with constant as second argument still keeps constant as canonical."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        regular_op = test.TestOp(result_types=(i32,))
+        const_op = arith.ConstantOp(IntegerAttr(42, i32))
+
+        regular_eclass = equivalence.ClassOp(regular_op.results[0], res_type=i32)
+        const_eclass = equivalence.ConstantClassOp(const_op.result)
+
+    ematch_funcs.eclass_union_find.add(regular_eclass)
+    ematch_funcs.eclass_union_find.add(const_eclass)
+
+    result = ematch_funcs.eclass_union(interpreter, regular_eclass, const_eclass)
+    assert result is True
+
+    canonical = ematch_funcs.eclass_union_find.find(regular_eclass)
+    assert isinstance(canonical, equivalence.ConstantClassOp)
+    assert canonical is const_eclass
+
+
+def test_eclass_union_two_same_constants():
+    """Union of two ConstantClassOps with the same value succeeds."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        const1 = arith.ConstantOp(IntegerAttr(5, i32))
+        const2 = arith.ConstantOp(IntegerAttr(5, i32))
+
+        const_eclass1 = equivalence.ConstantClassOp(const1.result)
+        const_eclass2 = equivalence.ConstantClassOp(const2.result)
+
+    ematch_funcs.eclass_union_find.add(const_eclass1)
+    ematch_funcs.eclass_union_find.add(const_eclass2)
+
+    result = ematch_funcs.eclass_union(interpreter, const_eclass1, const_eclass2)
+    assert result is True
+
+    canonical = ematch_funcs.eclass_union_find.find(const_eclass1)
+    assert isinstance(canonical, equivalence.ConstantClassOp)
+    assert set(canonical.operands) == {const1.result, const2.result}
+
+
+def test_eclass_union_two_different_constants_fails():
+    """Union of two ConstantClassOps with different values raises AssertionError."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        const1 = arith.ConstantOp(IntegerAttr(1, i32))
+        const2 = arith.ConstantOp(IntegerAttr(2, i32))
+
+        const_eclass1 = equivalence.ConstantClassOp(const1.result)
+        const_eclass2 = equivalence.ConstantClassOp(const2.result)
+
+    ematch_funcs.eclass_union_find.add(const_eclass1)
+    ematch_funcs.eclass_union_find.add(const_eclass2)
+
+    with pytest.raises(
+        AssertionError, match="Trying to union two different constant eclasses."
+    ):
+        ematch_funcs.eclass_union(interpreter, const_eclass1, const_eclass2)
+
+
+def test_eclass_union_removes_uses_from_known_ops():
+    """Uses of the replaced eclass are removed from known_ops before replacement."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+        eclass_a = equivalence.ClassOp(v0, res_type=i32)
+        eclass_b = equivalence.ClassOp(v1, res_type=i32)
+
+        # Create an operation that uses eclass_b's result — it should be
+        # removed from known_ops when eclass_b is replaced.
+        user_op = test.TestOp((eclass_b.result,), result_types=(i32,))
+
+    ematch_funcs.eclass_union_find.add(eclass_a)
+    ematch_funcs.eclass_union_find.add(eclass_b)
+    ematch_funcs.known_ops[user_op] = user_op
+
+    ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b)
+
+    # user_op used eclass_b.result, so it must have been popped from known_ops
+    assert user_op not in ematch_funcs.known_ops
+
+
+def test_eclass_union_deduplicates_operands():
+    """When the same value is an operand of both eclasses, it appears only once after union."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        shared = test.TestOp(result_types=(i32,)).results[0]
+        extra = test.TestOp(result_types=(i32,)).results[0]
+        eclass_a = equivalence.ClassOp(shared, res_type=i32)
+        eclass_b = equivalence.ClassOp(shared, extra, res_type=i32)
+
+    ematch_funcs.eclass_union_find.add(eclass_a)
+    ematch_funcs.eclass_union_find.add(eclass_b)
+
+    ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b)
+
+    canonical = ematch_funcs.eclass_union_find.find(eclass_a)
+    # shared should appear only once
+    operand_list = list(canonical.operands)
+    assert operand_list.count(shared) == 1
+    assert set(canonical.operands) == {shared, extra}
+
+
+def test_eclass_union_meets_analysis_states():
+    """Analysis lattice states are met when eclasses are unioned."""
+    from typing_extensions import Self
+
+    from xdsl.analysis.dataflow import DataFlowSolver
+    from xdsl.analysis.sparse_analysis import (
+        AbstractLatticeValue,
+        Lattice,
+        SparseForwardDataFlowAnalysis,
+    )
+
+    @dataclass(frozen=True)
+    class TestLatticeValue(AbstractLatticeValue):
+        value: int
+
+        @classmethod
+        def initial_value(cls) -> Self:
+            return cls(0)
+
+        def meet(self, other: "TestLatticeValue") -> "TestLatticeValue":
+            return TestLatticeValue(min(self.value, other.value))
+
+        def join(self, other: "TestLatticeValue") -> "TestLatticeValue":
+            return TestLatticeValue(max(self.value, other.value))
+
+    class TestLattice(Lattice[TestLatticeValue]):
+        value_cls = TestLatticeValue
+
+    class TestAnalysis(SparseForwardDataFlowAnalysis[TestLattice]):
+        def __init__(self, solver: DataFlowSolver):
+            super().__init__(solver, TestLattice)
+
+        def visit_operation_impl(
+            self,
+            op: Operation,
+            operands: list[TestLattice],
+            results: list[TestLattice],
+        ) -> None:
+            pass
+
+        def set_to_entry_state(self, lattice: TestLattice) -> None:
+            pass
+
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    ctx = PDLInterpFunctions.get_ctx(interpreter)
+    solver = NonPropagatingDataFlowSolver(ctx)
+    analysis = TestAnalysis(solver)
+    ematch_funcs.analyses.append(
+        cast(SparseForwardDataFlowAnalysis[Lattice[Any]], analysis)
+    )
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+        eclass_a = equivalence.ClassOp(v0, res_type=i32)
+        eclass_b = equivalence.ClassOp(v1, res_type=i32)
+
+    ematch_funcs.eclass_union_find.add(eclass_a)
+    ematch_funcs.eclass_union_find.add(eclass_b)
+
+    # Set lattice values: a=10, b=3 → meet = min(10, 3) = 3
+    lattice_a = analysis.get_lattice_element(eclass_a.result)
+    lattice_a._value = TestLatticeValue(10)  # pyright: ignore[reportPrivateUsage]
+    lattice_b = analysis.get_lattice_element(eclass_b.result)
+    lattice_b._value = TestLatticeValue(3)  # pyright: ignore[reportPrivateUsage]
+
+    ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b)
+
+    # The surviving eclass (a) should have the met value
+    canonical = ematch_funcs.eclass_union_find.find(eclass_a)
+    result_lattice = analysis.get_lattice_element(canonical.result)
+    assert result_lattice.value.value == 3
+
+
+def test_eclass_union_removes_uses_not_in_known_ops():
+    """Operations using the replaced eclass that are NOT in known_ops should not cause errors."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+        eclass_a = equivalence.ClassOp(v0, res_type=i32)
+        eclass_b = equivalence.ClassOp(v1, res_type=i32)
+
+        # Create an operation that uses eclass_b's result but do NOT add it to known_ops
+        user_op = test.TestOp((eclass_b.result,), result_types=(i32,))
+
+    ematch_funcs.eclass_union_find.add(eclass_a)
+    ematch_funcs.eclass_union_find.add(eclass_b)
+
+    # user_op is intentionally not in known_ops
+    assert user_op not in ematch_funcs.known_ops
+
+    # Should succeed without error
+    ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b)
+
+    # user_op should still not be in known_ops
+    assert user_op not in ematch_funcs.known_ops
+
+
+def test_union_val_same_value():
+    """Union of a value with itself is a no-op."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+
+    ematch_funcs.union_val(interpreter, v0, v0)
+
+    # No worklist entries should be created
+    assert not ematch_funcs.worklist
+
+
+def test_union_val_already_same_eclass():
+    """Union of two values already in the same eclass is a no-op (via eclass_union)."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+        eclass = equivalence.ClassOp(v0, v1, res_type=i32)
+
+    ematch_funcs.eclass_union_find.add(eclass)
+
+    ematch_funcs.union_val(interpreter, v0, v1)
+
+    # Both already in the same eclass, so no worklist entries
+    assert not ematch_funcs.worklist
+
+
+def test_run_union_operation_and_value_range():
+    """Union of an operation with a value range merges results pairwise."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        op = test.TestOp(result_types=(i32, i32))
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+
+    interpreter.run_op(
+        ematch.UnionOp(
+            create_ssa_value(pdl.OperationType()),
+            create_ssa_value(pdl.RangeType(pdl.ValueType())),
+        ),
+        (op, (v0, v1)),
+    )
+
+    eclass_r0 = ematch_funcs.get_or_create_class(interpreter, op.results[0])
+    eclass_v0 = ematch_funcs.get_or_create_class(interpreter, v0)
+    assert ematch_funcs.eclass_union_find.find(
+        eclass_r0
+    ) is ematch_funcs.eclass_union_find.find(eclass_v0)
+
+    eclass_r1 = ematch_funcs.get_or_create_class(interpreter, op.results[1])
+    eclass_v1 = ematch_funcs.get_or_create_class(interpreter, v1)
+    assert ematch_funcs.eclass_union_find.find(
+        eclass_r1
+    ) is ematch_funcs.eclass_union_find.find(eclass_v1)
+
+
+def test_run_union_two_value_ranges():
+    """Union of two value ranges merges values pairwise."""
+    interpreter, ematch_funcs, block = _make_interpreter_with_rewriter()
+
+    with ImplicitBuilder(block):
+        v0 = test.TestOp(result_types=(i32,)).results[0]
+        v1 = test.TestOp(result_types=(i32,)).results[0]
+        v2 = test.TestOp(result_types=(i32,)).results[0]
+        v3 = test.TestOp(result_types=(i32,)).results[0]
+
+    interpreter.run_op(
+        ematch.UnionOp(
+            create_ssa_value(pdl.RangeType(pdl.ValueType())),
+            create_ssa_value(pdl.RangeType(pdl.ValueType())),
+        ),
+        ((v0, v1), (v2, v3)),
+    )
+
+    eclass_v0 = ematch_funcs.get_or_create_class(interpreter, v0)
+    eclass_v2 = ematch_funcs.get_or_create_class(interpreter, v2)
+    assert ematch_funcs.eclass_union_find.find(
+        eclass_v0
+    ) is ematch_funcs.eclass_union_find.find(eclass_v2)
+
+    eclass_v1 = ematch_funcs.get_or_create_class(interpreter, v1)
+    eclass_v3 = ematch_funcs.get_or_create_class(interpreter, v3)
+    assert ematch_funcs.eclass_union_find.find(
+        eclass_v1
+    ) is ematch_funcs.eclass_union_find.find(eclass_v3)
+
+
+def test_run_union_unsupported_types():
+    """Union with unsupported argument types raises InterpretationError."""
+    interpreter, _ematch_funcs, _block = _make_interpreter_with_rewriter()
+
+    with pytest.raises(InterpretationError, match="unsupported argument types"):
+        interpreter.run_op(
+            ematch.UnionOp(
+                create_ssa_value(pdl.ValueType()),
+                create_ssa_value(pdl.ValueType()),
+            ),
+            ("not_a_value", 42),
+        )
diff --git a/xdsl/interpreters/ematch.py b/xdsl/interpreters/ematch.py
index 432554cc8a..cabcb5c88c 100644
--- a/xdsl/interpreters/ematch.py
+++ b/xdsl/interpreters/ematch.py
@@ -1,12 +1,18 @@
+from collections.abc import Sequence
 from dataclasses import dataclass, field
 from typing import Any
 
+from ordered_set import OrderedSet
+
+from xdsl.analysis.sparse_analysis import Lattice, SparseForwardDataFlowAnalysis
 from xdsl.dialects import ematch, equivalence
 from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls
 from xdsl.interpreters.pdl_interp import PDLInterpFunctions
-from xdsl.ir import Block, OpResult, SSAValue
+from xdsl.ir import Block, Operation, OpResult, SSAValue
 from xdsl.rewriter import InsertPoint
+from xdsl.transforms.common_subexpression_elimination import KnownOps
 from xdsl.utils.disjoint_set import DisjointSet
+from xdsl.utils.exceptions import InterpretationError
 from xdsl.utils.hints import isa
 
 
@@ -15,11 +21,28 @@
 class EmatchFunctions(InterpreterFunctions):
     """Interpreter functions for PDL patterns operating on e-graphs."""
 
+    known_ops: KnownOps = field(default_factory=KnownOps)
+    """Used for hashconsing operations. When new operations are created, if they are identical to an existing operation,
+    the existing operation is reused instead of creating a new one."""
+
     eclass_union_find: DisjointSet[equivalence.AnyClassOp] = field(
         default_factory=lambda: DisjointSet[equivalence.AnyClassOp]()
     )
     """Union-find structure tracking which e-classes are equivalent and should be merged."""
 
+    worklist: list[equivalence.AnyClassOp] = field(
+        default_factory=list[equivalence.AnyClassOp]
+    )
+    """Worklist of e-classes that need to be processed for matching."""
+
+    analyses: list[SparseForwardDataFlowAnalysis[Lattice[Any]]] = field(
+        default_factory=lambda: []
+    )
+    """The sparse forward analyses to be run during equality saturation.
+    These must be registered with a NonPropagatingDataFlowSolver where `propagate` is False.
+    This way, state propagation is handled purely by the equality saturation logic.
+    """
+
     @impl(ematch.GetClassValsOp)
     def run_get_class_vals(
         self,
@@ -169,3 +192,123 @@ def get_or_create_class(
         )
 
         return eclass_op
+
+    def eclass_union(
+        self,
+        interpreter: Interpreter,
+        a: equivalence.AnyClassOp,
+        b: equivalence.AnyClassOp,
+    ) -> bool:
+        """Union two e-classes, merging the operands of one into the other.
+
+        Both arguments are first resolved to their union-find representatives. A
+        ConstantClassOp is always kept as the surviving class; otherwise the
+        survivor is the representative chosen by the union-find. The other class
+        is replaced by the survivor through the pattern rewriter.
+
+        Returns True if the e-classes were merged, False if they were already the
+        same class.
+
+        Raises an AssertionError if both are constant classes with different values.
+        """
+        a = self.eclass_union_find.find(a)
+        b = self.eclass_union_find.find(b)
+
+        if a == b:
+            return False
+
+        if isinstance(a, equivalence.ConstantClassOp):
+            if isinstance(b, equivalence.ConstantClassOp):
+                assert a.value == b.value, (
+                    "Trying to union two different constant eclasses."
+                )
+            to_keep, to_replace = a, b
+            self.eclass_union_find.union_left(to_keep, to_replace)
+        elif isinstance(b, equivalence.ConstantClassOp):
+            to_keep, to_replace = b, a
+            self.eclass_union_find.union_left(to_keep, to_replace)
+        else:
+            self.eclass_union_find.union(a, b)
+            to_keep = self.eclass_union_find.find(a)
+            to_replace = b if to_keep is a else a
+
+        # Meet the analysis states into the surviving e-class. The replaced class's
+        # result is substituted by `to_keep.result` below, so the met state has to
+        # live on the survivor's lattice element rather than unconditionally on `a`.
+        for analysis in self.analyses:
+            keep_lattice = analysis.get_lattice_element(to_keep.result)
+            replace_lattice = analysis.get_lattice_element(to_replace.result)
+            keep_lattice.meet(replace_lattice)
+
+        # Operands need to be deduplicated because it can happen the same operand was
+        # used by different parent eclasses after their children were merged:
+        new_operands = OrderedSet(to_keep.operands)
+        new_operands.update(to_replace.operands)
+        to_keep.operands = new_operands
+
+        for use in to_replace.result.uses:
+            # uses are removed from the hashcons before the replacement is carried out.
+            # (because the replacement changes the operations which means we cannot find them in the hashcons anymore)
+            if use.operation in self.known_ops:
+                self.known_ops.pop(use.operation)
+
+        rewriter = PDLInterpFunctions.get_rewriter(interpreter)
+        rewriter.replace_op(to_replace, new_ops=[], new_results=to_keep.results)
+        return True
+
+    def union_val(self, interpreter: Interpreter, a: SSAValue, b: SSAValue) -> None:
+        """
+        Union two values into the same equivalence class.
+        """
+        if a == b:
+            return
+
+        eclass_a = self.get_or_create_class(interpreter, a)
+        eclass_b = self.get_or_create_class(interpreter, b)
+
+        if self.eclass_union(interpreter, eclass_a, eclass_b):
+            self.worklist.append(eclass_a)
+
+    @impl(ematch.UnionOp)
+    def run_union(
+        self,
+        interpreter: Interpreter,
+        op: ematch.UnionOp,
+        args: tuple[Any, ...],
+    ) -> tuple[Any, ...]:
+        """
+        Merge two values, an operation and a value range, or two value ranges
+        into equivalence class(es).
+
+        Supported operand type combinations:
+        - (value, value): merge two values
+        - (operation, range): merge operation results with values
+        - (range, range): merge two value ranges
+        """
+        assert len(args) == 2
+        lhs, rhs = args
+
+        if isa(lhs, SSAValue) and isa(rhs, SSAValue):
+            # (Value, Value) case
+            self.union_val(interpreter, lhs, rhs)
+
+        elif isinstance(lhs, Operation) and isa(rhs, Sequence[SSAValue]):
+            # (Operation, ValueRange) case
+            assert len(lhs.results) == len(rhs), (
+                "Operation result count must match value range size"
+            )
+            for result, val in zip(lhs.results, rhs, strict=True):
+                self.union_val(interpreter, result, val)
+
+        elif isa(lhs, Sequence[SSAValue]) and isa(rhs, Sequence[SSAValue]):
+            # (ValueRange, ValueRange) case
+            assert len(lhs) == len(rhs), "Value ranges must have equal size"
+            for val_lhs, val_rhs in zip(lhs, rhs, strict=True):
+                self.union_val(interpreter, val_lhs, val_rhs)
+
+        else:
+            raise InterpretationError(
+                f"union: unsupported argument types: {type(lhs)}, {type(rhs)}"
+            )
+
+        return ()