From 8698b9c7b23e9a299e74d2e831cf88abde84919a Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:28:48 +0100 Subject: [PATCH 1/9] add worklist field --- xdsl/interpreters/ematch.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xdsl/interpreters/ematch.py b/xdsl/interpreters/ematch.py index 432554cc8a..c6be1df44f 100644 --- a/xdsl/interpreters/ematch.py +++ b/xdsl/interpreters/ematch.py @@ -20,6 +20,11 @@ class EmatchFunctions(InterpreterFunctions): ) """Union-find structure tracking which e-classes are equivalent and should be merged.""" + worklist: list[equivalence.AnyClassOp] = field( + default_factory=list[equivalence.AnyClassOp] + ) + """Worklist of e-classes that need to be processed for matching.""" + @impl(ematch.GetClassValsOp) def run_get_class_vals( self, From 7327e30b7675e5e957c483b53c89f61bea6285f1 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:30:27 +0100 Subject: [PATCH 2/9] add analyses field --- xdsl/interpreters/ematch.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xdsl/interpreters/ematch.py b/xdsl/interpreters/ematch.py index c6be1df44f..b17c561335 100644 --- a/xdsl/interpreters/ematch.py +++ b/xdsl/interpreters/ematch.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Any +from xdsl.analysis.sparse_analysis import Lattice, SparseForwardDataFlowAnalysis from xdsl.dialects import ematch, equivalence from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls from xdsl.interpreters.pdl_interp import PDLInterpFunctions @@ -25,6 +26,14 @@ class EmatchFunctions(InterpreterFunctions): ) """Worklist of e-classes that need to be processed for matching.""" + analyses: list[SparseForwardDataFlowAnalysis[Lattice[Any]]] = field( + default_factory=lambda: [] + ) + """The sparse forward analyses to be run during equality saturation. + These must be registered with a NonPropagatingDataFlowSolver where `propagate` is False. + This way, state propagation is handled purely by the equality saturation logic. + """ + @impl(ematch.GetClassValsOp) def run_get_class_vals( self, From 88fa119379912d886e3dc6204ac1b4fbeecc6736 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:46:31 +0100 Subject: [PATCH 3/9] add known_ops field + related utilities --- xdsl/interpreters/ematch.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xdsl/interpreters/ematch.py b/xdsl/interpreters/ematch.py index b17c561335..54e9b72286 100644 --- a/xdsl/interpreters/ematch.py +++ b/xdsl/interpreters/ematch.py @@ -7,6 +7,7 @@ from xdsl.interpreters.pdl_interp import PDLInterpFunctions from xdsl.ir import Block, OpResult, SSAValue from xdsl.rewriter import InsertPoint +from xdsl.transforms.common_subexpression_elimination import KnownOps from xdsl.utils.disjoint_set import DisjointSet from xdsl.utils.hints import isa @@ -16,6 +17,10 @@ class EmatchFunctions(InterpreterFunctions): """Interpreter functions for PDL patterns operating on e-graphs.""" + known_ops: KnownOps = field(default_factory=KnownOps) + """Used for hashconsing operations. When new operations are created, if they are identical to an existing operation, + the existing operation is reused instead of creating a new one.""" + eclass_union_find: DisjointSet[equivalence.AnyClassOp] = field( default_factory=lambda: DisjointSet[equivalence.AnyClassOp]() ) From 730b6201ac5d6a66c6d3b2d0b34b831966e0333c Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 1 Mar 2026 10:55:08 +0100 Subject: [PATCH 4/9] ematch: add eclass_union --- xdsl/interpreters/ematch.py | 55 +++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/xdsl/interpreters/ematch.py b/xdsl/interpreters/ematch.py index 54e9b72286..ac7113df0e 100644 --- a/xdsl/interpreters/ematch.py +++ b/xdsl/interpreters/ematch.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field from typing import Any +from ordered_set import OrderedSet + from xdsl.analysis.sparse_analysis import Lattice, SparseForwardDataFlowAnalysis from xdsl.dialects import ematch, equivalence from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls @@ -188,3 +190,56 @@ def get_or_create_class( ) return eclass_op + + def eclass_union( + self, + interpreter: Interpreter, + a: equivalence.AnyClassOp, + b: equivalence.AnyClassOp, + ) -> bool: + """Unions two eclasses, merging their operands and results. + Returns True if the eclasses were merged, False if they were already the same.""" + a = self.eclass_union_find.find(a) + b = self.eclass_union_find.find(b) + + if a == b: + return False + + # Meet the analysis states of the two e-classes + for analysis in self.analyses: + a_lattice = analysis.get_lattice_element(a.result) + b_lattice = analysis.get_lattice_element(b.result) + a_lattice.meet(b_lattice) + + if isinstance(a, equivalence.ConstantClassOp): + if isinstance(b, equivalence.ConstantClassOp): + assert a.value == b.value, ( + "Trying to union two different constant eclasses.", + ) + to_keep, to_replace = a, b + self.eclass_union_find.union_left(to_keep, to_replace) + elif isinstance(b, equivalence.ConstantClassOp): + to_keep, to_replace = b, a + self.eclass_union_find.union_left(to_keep, to_replace) + else: + self.eclass_union_find.union( + a, + b, + ) + to_keep = self.eclass_union_find.find(a) + to_replace = b if to_keep is a else a + # Operands need to be deduplicated because it can happen the same operand was + # used by different parent eclasses after their children were merged: + new_operands = OrderedSet(to_keep.operands) + new_operands.update(to_replace.operands) + to_keep.operands = new_operands + + for use in to_replace.result.uses: + # uses are removed from the hashcons before the replacement is carried out. + # (because the replacement changes the operations which means we cannot find them in the hashcons anymore) + if use.operation in self.known_ops: + self.known_ops.pop(use.operation) + + rewriter = PDLInterpFunctions.get_rewriter(interpreter) + rewriter.replace_op(to_replace, new_ops=[], new_results=to_keep.results) + return True From b960d4a2eaa68eb9fd8f17c15b649e2ec6f2ce8e Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 1 Mar 2026 10:54:42 +0100 Subject: [PATCH 5/9] ematch: add union_val, run_union import Operation --- xdsl/interpreters/ematch.py | 61 ++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/xdsl/interpreters/ematch.py b/xdsl/interpreters/ematch.py index ac7113df0e..cabcb5c88c 100644 --- a/xdsl/interpreters/ematch.py +++ b/xdsl/interpreters/ematch.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from dataclasses import dataclass, field from typing import Any @@ -7,10 +8,11 @@ from xdsl.dialects import ematch, equivalence from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls from xdsl.interpreters.pdl_interp import PDLInterpFunctions -from xdsl.ir import Block, OpResult, SSAValue +from xdsl.ir import Block, Operation, OpResult, SSAValue from xdsl.rewriter import InsertPoint from xdsl.transforms.common_subexpression_elimination import KnownOps from xdsl.utils.disjoint_set import DisjointSet +from xdsl.utils.exceptions import InterpretationError from xdsl.utils.hints import isa @@ -243,3 +245,60 @@ def eclass_union( rewriter = PDLInterpFunctions.get_rewriter(interpreter) rewriter.replace_op(to_replace, new_ops=[], new_results=to_keep.results) return True + + def union_val(self, interpreter: Interpreter, a: SSAValue, b: SSAValue) -> None: + """ + Union two values into the same equivalence class. + """ + if a == b: + return + + eclass_a = self.get_or_create_class(interpreter, a) + eclass_b = self.get_or_create_class(interpreter, b) + + if self.eclass_union(interpreter, eclass_a, eclass_b): + self.worklist.append(eclass_a) + + @impl(ematch.UnionOp) + def run_union( + self, + interpreter: Interpreter, + op: ematch.UnionOp, + args: tuple[Any, ...], + ) -> tuple[Any, ...]: + """ + Merge two values, an operation and a value range, or two value ranges + into equivalence class(es). + + Supported operand type combinations: + - (value, value): merge two values + - (operation, range): merge operation results with values + - (range, range): merge two value ranges + """ + assert len(args) == 2 + lhs, rhs = args + + if isa(lhs, SSAValue) and isa(rhs, SSAValue): + # (Value, Value) case + self.union_val(interpreter, lhs, rhs) + + elif isinstance(lhs, Operation) and isa(rhs, Sequence[SSAValue]): + # (Operation, ValueRange) case + assert len(lhs.results) == len(rhs), ( + "Operation result count must match value range size" + ) + for result, val in zip(lhs.results, rhs, strict=True): + self.union_val(interpreter, result, val) + + elif isa(lhs, Sequence[SSAValue]) and isa(rhs, Sequence[SSAValue]): + # (ValueRange, ValueRange) case + assert len(lhs) == len(rhs), "Value ranges must have equal size" + for val_lhs, val_rhs in zip(lhs, rhs, strict=True): + self.union_val(interpreter, val_lhs, val_rhs) + + else: + raise InterpretationError( + f"union: unsupported argument types: {type(lhs)}, {type(rhs)}" + ) + + return () From ad978af77c6e00612d1e5f1bb0de43e2257ee2bf Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:50:34 +0100 Subject: [PATCH 6/9] ematch: test union_val --- tests/interpreters/test_ematch_interpreter.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/interpreters/test_ematch_interpreter.py b/tests/interpreters/test_ematch_interpreter.py index 419acb2c67..730d747a70 100644 --- a/tests/interpreters/test_ematch_interpreter.py +++ b/tests/interpreters/test_ematch_interpreter.py @@ -233,3 +233,24 @@ def test_get_or_create_class_creates_new_class_for_block_arg(): assert isinstance(result, equivalence.ClassOp) assert block_arg in result.operands assert ematch_funcs.eclass_union_find.find(result) is result + + +def test_union_val(): + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + + interpreter.run_op( + ematch.UnionOp( + create_ssa_value(pdl.ValueType()), create_ssa_value(pdl.ValueType()) + ), + (v0, v1), + ) + + # After union, both values should be operands of the same ClassOp + eclass_a = ematch_funcs.get_or_create_class(interpreter, v0) + eclass_b = ematch_funcs.get_or_create_class(interpreter, v1) + assert eclass_a is eclass_b + assert set(eclass_a.operands) == {v0, v1} From 8c0aa12dc0c13bcf85ff343b4775511150c6b239 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sat, 7 Mar 2026 21:55:03 +0100 Subject: [PATCH 7/9] more tests --- tests/interpreters/test_ematch_interpreter.py | 256 +++++++++++++++++- 1 file changed, 253 insertions(+), 3 deletions(-) diff --git a/tests/interpreters/test_ematch_interpreter.py b/tests/interpreters/test_ematch_interpreter.py index 730d747a70..09ede799d6 100644 --- a/tests/interpreters/test_ematch_interpreter.py +++ b/tests/interpreters/test_ematch_interpreter.py @@ -1,10 +1,16 @@ +from dataclasses import dataclass +from typing import Any, cast + +import pytest + from xdsl.builder import ImplicitBuilder -from xdsl.dialects import ematch, equivalence, pdl, test -from xdsl.dialects.builtin import ModuleOp, i32 +from xdsl.dialects import arith, ematch, equivalence, pdl, test +from xdsl.dialects.builtin import IntegerAttr, ModuleOp, i32 from xdsl.interpreter import Interpreter from xdsl.interpreters.ematch import EmatchFunctions +from xdsl.interpreters.eqsat_pdl_interp import NonPropagatingDataFlowSolver from xdsl.interpreters.pdl_interp import PDLInterpFunctions -from xdsl.ir import Block, Region +from xdsl.ir import Block, Operation, Region from xdsl.pattern_rewriter import PatternRewriter from xdsl.utils.test_value import create_ssa_value @@ -254,3 +260,247 @@ def test_union_val(): eclass_b = ematch_funcs.get_or_create_class(interpreter, v1) assert eclass_a is eclass_b assert set(eclass_a.operands) == {v0, v1} + + +def test_eclass_union_same_class(): + """Union of the same eclass with itself is a no-op.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + eclass = equivalence.ClassOp(v0, res_type=i32) + + ematch_funcs.eclass_union_find.add(eclass) + + result = ematch_funcs.eclass_union(interpreter, eclass, eclass) + assert result is False + assert len(ematch_funcs.worklist) == 0 + + +def test_eclass_union_two_regular_classes(): + """Union of two regular eclasses merges operands and replaces one.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + eclass_a = equivalence.ClassOp(v0, res_type=i32) + eclass_b = equivalence.ClassOp(v1, res_type=i32) + + ematch_funcs.eclass_union_find.add(eclass_a) + ematch_funcs.eclass_union_find.add(eclass_b) + + result = ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b) + assert result is True + + canonical = ematch_funcs.eclass_union_find.find(eclass_a) + assert ematch_funcs.eclass_union_find.find(eclass_b) is canonical + assert set(canonical.operands) == {v0, v1} + + +def test_eclass_union_constant_with_regular(): + """Union of ConstantClassOp with regular ClassOp keeps the constant as canonical.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + const_op = arith.ConstantOp(IntegerAttr(1, i32)) + regular_op = test.TestOp(result_types=(i32,)) + + const_eclass = equivalence.ConstantClassOp(const_op.result) + regular_eclass = equivalence.ClassOp(regular_op.results[0], res_type=i32) + + ematch_funcs.eclass_union_find.add(const_eclass) + ematch_funcs.eclass_union_find.add(regular_eclass) + + result = ematch_funcs.eclass_union(interpreter, const_eclass, regular_eclass) + assert result is True + + canonical = ematch_funcs.eclass_union_find.find(const_eclass) + assert isinstance(canonical, equivalence.ConstantClassOp) + assert canonical.value == IntegerAttr(1, i32) + assert set(canonical.operands) == {const_op.result, regular_op.results[0]} + + +def test_eclass_union_regular_with_constant(): + """Union with constant as second argument still keeps constant as canonical.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + regular_op = test.TestOp(result_types=(i32,)) + const_op = arith.ConstantOp(IntegerAttr(42, i32)) + + regular_eclass = equivalence.ClassOp(regular_op.results[0], res_type=i32) + const_eclass = equivalence.ConstantClassOp(const_op.result) + + ematch_funcs.eclass_union_find.add(regular_eclass) + ematch_funcs.eclass_union_find.add(const_eclass) + + result = ematch_funcs.eclass_union(interpreter, regular_eclass, const_eclass) + assert result is True + + canonical = ematch_funcs.eclass_union_find.find(regular_eclass) + assert isinstance(canonical, equivalence.ConstantClassOp) + assert canonical is const_eclass + + +def test_eclass_union_two_same_constants(): + """Union of two ConstantClassOps with the same value succeeds.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + const1 = arith.ConstantOp(IntegerAttr(5, i32)) + const2 = arith.ConstantOp(IntegerAttr(5, i32)) + + const_eclass1 = equivalence.ConstantClassOp(const1.result) + const_eclass2 = equivalence.ConstantClassOp(const2.result) + + ematch_funcs.eclass_union_find.add(const_eclass1) + ematch_funcs.eclass_union_find.add(const_eclass2) + + result = ematch_funcs.eclass_union(interpreter, const_eclass1, const_eclass2) + assert result is True + + canonical = ematch_funcs.eclass_union_find.find(const_eclass1) + assert isinstance(canonical, equivalence.ConstantClassOp) + assert set(canonical.operands) == {const1.result, const2.result} + + +def test_eclass_union_two_different_constants_fails(): + """Union of two ConstantClassOps with different values raises AssertionError.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + const1 = arith.ConstantOp(IntegerAttr(1, i32)) + const2 = arith.ConstantOp(IntegerAttr(2, i32)) + + const_eclass1 = equivalence.ConstantClassOp(const1.result) + const_eclass2 = equivalence.ConstantClassOp(const2.result) + + ematch_funcs.eclass_union_find.add(const_eclass1) + ematch_funcs.eclass_union_find.add(const_eclass2) + + with pytest.raises( + AssertionError, match="Trying to union two different constant eclasses." + ): + ematch_funcs.eclass_union(interpreter, const_eclass1, const_eclass2) + + +def test_eclass_union_removes_uses_from_known_ops(): + """Uses of the replaced eclass are removed from known_ops before replacement.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + eclass_a = equivalence.ClassOp(v0, res_type=i32) + eclass_b = equivalence.ClassOp(v1, res_type=i32) + + # Create an operation that uses eclass_b's result — it should be + # removed from known_ops when eclass_b is replaced. + user_op = test.TestOp((eclass_b.result,), result_types=(i32,)) + + ematch_funcs.eclass_union_find.add(eclass_a) + ematch_funcs.eclass_union_find.add(eclass_b) + ematch_funcs.known_ops[user_op] = user_op + + ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b) + + # user_op used eclass_b.result, so it must have been popped from known_ops + assert user_op not in ematch_funcs.known_ops + + +def test_eclass_union_deduplicates_operands(): + """When the same value is an operand of both eclasses, it appears only once after union.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + shared = test.TestOp(result_types=(i32,)).results[0] + extra = test.TestOp(result_types=(i32,)).results[0] + eclass_a = equivalence.ClassOp(shared, res_type=i32) + eclass_b = equivalence.ClassOp(shared, extra, res_type=i32) + + ematch_funcs.eclass_union_find.add(eclass_a) + ematch_funcs.eclass_union_find.add(eclass_b) + + ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b) + + canonical = ematch_funcs.eclass_union_find.find(eclass_a) + # shared should appear only once + operand_list = list(canonical.operands) + assert operand_list.count(shared) == 1 + assert set(canonical.operands) == {shared, extra} + + +def test_eclass_union_meets_analysis_states(): + """Analysis lattice states are met when eclasses are unioned.""" + from typing_extensions import Self + + from xdsl.analysis.dataflow import DataFlowSolver + from xdsl.analysis.sparse_analysis import ( + AbstractLatticeValue, + Lattice, + SparseForwardDataFlowAnalysis, + ) + + @dataclass(frozen=True) + class TestLatticeValue(AbstractLatticeValue): + value: int + + @classmethod + def initial_value(cls) -> Self: + return cls(0) + + def meet(self, other: "TestLatticeValue") -> "TestLatticeValue": + return TestLatticeValue(min(self.value, other.value)) + + def join(self, other: "TestLatticeValue") -> "TestLatticeValue": + return TestLatticeValue(max(self.value, other.value)) + + class TestLattice(Lattice[TestLatticeValue]): + value_cls = TestLatticeValue + + class TestAnalysis(SparseForwardDataFlowAnalysis[TestLattice]): + def __init__(self, solver: DataFlowSolver): + super().__init__(solver, TestLattice) + + def visit_operation_impl( + self, + op: Operation, + operands: list[TestLattice], + results: list[TestLattice], + ) -> None: + pass + + def set_to_entry_state(self, lattice: TestLattice) -> None: + pass + + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + ctx = PDLInterpFunctions.get_ctx(interpreter) + solver = NonPropagatingDataFlowSolver(ctx) + analysis = TestAnalysis(solver) + ematch_funcs.analyses.append( + cast(SparseForwardDataFlowAnalysis[Lattice[Any]], analysis) + ) + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + eclass_a = equivalence.ClassOp(v0, res_type=i32) + eclass_b = equivalence.ClassOp(v1, res_type=i32) + + ematch_funcs.eclass_union_find.add(eclass_a) + ematch_funcs.eclass_union_find.add(eclass_b) + + # Set lattice values: a=10, b=3 → meet = min(10, 3) = 3 + lattice_a = analysis.get_lattice_element(eclass_a.result) + lattice_a._value = TestLatticeValue(10) # pyright: ignore[reportPrivateUsage] + lattice_b = analysis.get_lattice_element(eclass_b.result) + lattice_b._value = TestLatticeValue(3) # pyright: ignore[reportPrivateUsage] + + ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b) + + # The surviving eclass (a) should have the met value + canonical = ematch_funcs.eclass_union_find.find(eclass_a) + result_lattice = analysis.get_lattice_element(canonical.result) + assert result_lattice.value.value == 3 From 6a40079e1323fc1c3df379d52061d292f3a25b12 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 8 Mar 2026 16:46:49 +0100 Subject: [PATCH 8/9] better coverage --- tests/interpreters/test_ematch_interpreter.py | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/tests/interpreters/test_ematch_interpreter.py b/tests/interpreters/test_ematch_interpreter.py index 09ede799d6..2a90a218ec 100644 --- a/tests/interpreters/test_ematch_interpreter.py +++ b/tests/interpreters/test_ematch_interpreter.py @@ -12,6 +12,7 @@ from xdsl.interpreters.pdl_interp import PDLInterpFunctions from xdsl.ir import Block, Operation, Region from xdsl.pattern_rewriter import PatternRewriter +from xdsl.utils.exceptions import InterpretationError from xdsl.utils.test_value import create_ssa_value @@ -504,3 +505,134 @@ def set_to_entry_state(self, lattice: TestLattice) -> None: canonical = ematch_funcs.eclass_union_find.find(eclass_a) result_lattice = analysis.get_lattice_element(canonical.result) assert result_lattice.value.value == 3 + + +def test_eclass_union_removes_uses_not_in_known_ops(): + """Operations using the replaced eclass that are NOT in known_ops should not cause errors.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + eclass_a = equivalence.ClassOp(v0, res_type=i32) + eclass_b = equivalence.ClassOp(v1, res_type=i32) + + # Create an operation that uses eclass_b's result but do NOT add it to known_ops + user_op = test.TestOp((eclass_b.result,), result_types=(i32,)) + + ematch_funcs.eclass_union_find.add(eclass_a) + ematch_funcs.eclass_union_find.add(eclass_b) + + # user_op is intentionally not in known_ops + assert user_op not in ematch_funcs.known_ops + + # Should succeed without error + ematch_funcs.eclass_union(interpreter, eclass_a, eclass_b) + + # user_op should still not be in known_ops + assert user_op not in ematch_funcs.known_ops + + +def test_union_val_same_value(): + """Union of a value with itself is a no-op.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + + ematch_funcs.union_val(interpreter, v0, v0) + + # No worklist entries should be created + assert len(ematch_funcs.worklist) == 0 + + +def test_union_val_already_same_eclass(): + """Union of two values already in the same eclass is a no-op (via eclass_union).""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + eclass = equivalence.ClassOp(v0, v1, res_type=i32) + + ematch_funcs.eclass_union_find.add(eclass) + + ematch_funcs.union_val(interpreter, v0, v1) + + # Both already in the same eclass, so no worklist entries + assert len(ematch_funcs.worklist) == 0 + + +def test_run_union_operation_and_value_range(): + """Union of an operation with a value range merges results pairwise.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + op = test.TestOp(result_types=(i32, i32)) + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + + interpreter.run_op( + ematch.UnionOp( + create_ssa_value(pdl.OperationType()), + create_ssa_value(pdl.RangeType(pdl.ValueType())), + ), + (op, (v0, v1)), + ) + + eclass_r0 = ematch_funcs.get_or_create_class(interpreter, op.results[0]) + eclass_v0 = ematch_funcs.get_or_create_class(interpreter, v0) + assert ematch_funcs.eclass_union_find.find( + eclass_r0 + ) is ematch_funcs.eclass_union_find.find(eclass_v0) + + eclass_r1 = ematch_funcs.get_or_create_class(interpreter, op.results[1]) + eclass_v1 = ematch_funcs.get_or_create_class(interpreter, v1) + assert ematch_funcs.eclass_union_find.find( + eclass_r1 + ) is ematch_funcs.eclass_union_find.find(eclass_v1) + + +def test_run_union_two_value_ranges(): + """Union of two value ranges merges values pairwise.""" + interpreter, ematch_funcs, block = _make_interpreter_with_rewriter() + + with ImplicitBuilder(block): + v0 = test.TestOp(result_types=(i32,)).results[0] + v1 = test.TestOp(result_types=(i32,)).results[0] + v2 = test.TestOp(result_types=(i32,)).results[0] + v3 = test.TestOp(result_types=(i32,)).results[0] + + interpreter.run_op( + ematch.UnionOp( + create_ssa_value(pdl.RangeType(pdl.ValueType())), + create_ssa_value(pdl.RangeType(pdl.ValueType())), + ), + ((v0, v1), (v2, v3)), + ) + + eclass_v0 = ematch_funcs.get_or_create_class(interpreter, v0) + eclass_v2 = ematch_funcs.get_or_create_class(interpreter, v2) + assert ematch_funcs.eclass_union_find.find( + eclass_v0 + ) is ematch_funcs.eclass_union_find.find(eclass_v2) + + eclass_v1 = ematch_funcs.get_or_create_class(interpreter, v1) + eclass_v3 = ematch_funcs.get_or_create_class(interpreter, v3) + assert ematch_funcs.eclass_union_find.find( + eclass_v1 + ) is ematch_funcs.eclass_union_find.find(eclass_v3) + + +def test_run_union_unsupported_types(): + """Union with unsupported argument types raises InterpretationError.""" + interpreter, _ematch_funcs, _block = _make_interpreter_with_rewriter() + + with pytest.raises(InterpretationError, match="unsupported argument types"): + interpreter.run_op( + ematch.UnionOp( + create_ssa_value(pdl.ValueType()), + create_ssa_value(pdl.ValueType()), + ), + ("not_a_value", 42), + ) From 603f5cc44d4bc7b6a019f753494db4cb4029fd98 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Sun, 8 Mar 2026 17:51:58 +0100 Subject: [PATCH 9/9] Apply suggestions from code review Co-authored-by: Sasha Lopoukhine --- tests/interpreters/test_ematch_interpreter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/interpreters/test_ematch_interpreter.py b/tests/interpreters/test_ematch_interpreter.py index 2a90a218ec..72931e9792 100644 --- a/tests/interpreters/test_ematch_interpreter.py +++ b/tests/interpreters/test_ematch_interpreter.py @@ -543,7 +543,7 @@ def test_union_val_same_value(): ematch_funcs.union_val(interpreter, v0, v0) # No worklist entries should be created - assert len(ematch_funcs.worklist) == 0 + assert not ematch_funcs.worklist def test_union_val_already_same_eclass(): @@ -560,7 +560,7 @@ def test_union_val_already_same_eclass(): ematch_funcs.union_val(interpreter, v0, v1) # Both already in the same eclass, so no worklist entries - assert len(ematch_funcs.worklist) == 0 + assert not ematch_funcs.worklist def test_run_union_operation_and_value_range():