diff --git a/src/irx/builders/llvmliteir/core.py b/src/irx/builders/llvmliteir/core.py index aed37c43..4662a480 100644 --- a/src/irx/builders/llvmliteir/core.py +++ b/src/irx/builders/llvmliteir/core.py @@ -577,9 +577,26 @@ def _unify_numeric_operands( f"{lhs.type.count} vs {rhs.type.count}" ) if lhs.type.element != rhs.type.element: - raise Exception( - "Vector element type mismatch: " - f"{lhs.type.element} vs {rhs.type.element}" + lhs_elem = lhs.type.element + rhs_elem = rhs.type.element + lhs_is_fp = is_fp_type(lhs_elem) + rhs_is_fp = is_fp_type(rhs_elem) + + if lhs_is_fp or rhs_is_fp: + fp_candidates = [ + t for t in (lhs_elem, rhs_elem) if is_fp_type(t) + ] + target_scalar_ty = self._select_float_type(fp_candidates) + else: + lhs_w = getattr(lhs_elem, "width", 0) + rhs_w = getattr(rhs_elem, "width", 0) + target_scalar_ty = ir.IntType(max(lhs_w, rhs_w, 1)) + + lhs = self._cast_value_to_type( + lhs, target_scalar_ty, unsigned=unsigned + ) + rhs = self._cast_value_to_type( + rhs, target_scalar_ty, unsigned=unsigned ) return lhs, rhs diff --git a/src/irx/builders/llvmliteir/visitors/binary_ops.py b/src/irx/builders/llvmliteir/visitors/binary_ops.py index 1785dde4..22094b23 100644 --- a/src/irx/builders/llvmliteir/visitors/binary_ops.py +++ b/src/irx/builders/llvmliteir/visitors/binary_ops.py @@ -182,6 +182,8 @@ def _emit_vector_mul( node: MulBinOp, llvm_lhs: ir.Value, llvm_rhs: ir.Value, + *, + unsigned: bool = False, ) -> ir.Value | None: """ title: Emit vector mul. @@ -192,6 +194,8 @@ def _emit_vector_mul( type: ir.Value llvm_rhs: type: ir.Value + unsigned: + type: bool returns: type: ir.Value | None """ @@ -208,11 +212,12 @@ def _emit_vector_mul( llvm_fma_rhs = safe_pop(self.result_stack) if llvm_fma_rhs is None: raise Exception("FMA requires a valid third operand") - if llvm_fma_rhs.type != llvm_lhs.type: - raise Exception( - f"FMA operand type mismatch: " - f"{llvm_lhs.type} vs {llvm_fma_rhs.type}" - ) + llvm_lhs, llvm_fma_rhs = self._unify_numeric_operands( + llvm_lhs, llvm_fma_rhs, unsigned=unsigned + ) + llvm_rhs, llvm_fma_rhs = self._unify_numeric_operands( + llvm_rhs, llvm_fma_rhs, unsigned=unsigned + ) prev_fast_math = self._fast_math_enabled if set_fast: self.set_fast_math(True) @@ -306,7 +311,10 @@ def _emit_ordered_compare( returns: type: ir.Value """ - if is_fp_type(llvm_lhs.type): + scalar_ty = ( + llvm_lhs.type.element if is_vector(llvm_lhs) else llvm_lhs.type + ) + if is_fp_type(scalar_ty): return self._llvm.ir_builder.fcmp_ordered( op_code, llvm_lhs, @@ -432,9 +440,11 @@ def visit(self, node: MulBinOp) -> None: node: type: MulBinOp """ - llvm_lhs, llvm_rhs, _unsigned = self._load_binary_operands(node) + llvm_lhs, llvm_rhs, unsigned = self._load_binary_operands(node) - vector_result = self._emit_vector_mul(node, llvm_lhs, llvm_rhs) + vector_result = self._emit_vector_mul( + node, llvm_lhs, llvm_rhs, unsigned=unsigned + ) if vector_result is not None: self.result_stack.append(vector_result) return @@ -529,8 +539,6 @@ def visit(self, node: LtBinOp) -> None: type: LtBinOp """ llvm_lhs, llvm_rhs, unsigned = self._load_binary_operands(node) - if is_vector(llvm_lhs) and is_vector(llvm_rhs): - raise Exception(f"Vector binop {node.op_code} not implemented.") result = self._emit_ordered_compare( "<", llvm_lhs, @@ -549,8 +557,6 @@ def visit(self, node: GtBinOp) -> None: type: GtBinOp """ llvm_lhs, llvm_rhs, unsigned = self._load_binary_operands(node) - if is_vector(llvm_lhs) and is_vector(llvm_rhs): - raise Exception(f"Vector binop {node.op_code} not implemented.") result = self._emit_ordered_compare( ">", llvm_lhs, @@ -569,8 +575,6 @@ def visit(self, node: LeBinOp) -> None: type: LeBinOp """ llvm_lhs, llvm_rhs, unsigned = self._load_binary_operands(node) - if is_vector(llvm_lhs) and is_vector(llvm_rhs): - raise Exception(f"Vector binop {node.op_code} not implemented.") result = self._emit_ordered_compare( "<=", llvm_lhs, @@ -589,8 +593,6 @@ def visit(self, node: GeBinOp) -> None: type: GeBinOp """ llvm_lhs, llvm_rhs, unsigned = self._load_binary_operands(node) - if is_vector(llvm_lhs) and is_vector(llvm_rhs): - raise Exception(f"Vector binop {node.op_code} not implemented.") result = self._emit_ordered_compare( ">=", llvm_lhs, @@ -611,7 +613,11 @@ def visit(self, node: EqBinOp) -> None: llvm_lhs, llvm_rhs, unsigned = self._load_binary_operands(node) if is_vector(llvm_lhs) and is_vector(llvm_rhs): - raise Exception(f"Vector binop {node.op_code} not implemented.") + result = self._emit_ordered_compare( + "==", llvm_lhs, llvm_rhs, unsigned=unsigned, name="vcmptmp" + ) + self.result_stack.append(result) + return if ( isinstance(llvm_lhs.type, ir.PointerType) @@ -654,7 +660,11 @@ def visit(self, node: NeBinOp) -> None: llvm_lhs, llvm_rhs, unsigned = self._load_binary_operands(node) if is_vector(llvm_lhs) and is_vector(llvm_rhs): - raise Exception(f"Vector binop {node.op_code} not implemented.") + result = self._emit_ordered_compare( + "!=", llvm_lhs, llvm_rhs, unsigned=unsigned, name="vcmptmp" + ) + self.result_stack.append(result) + return if ( isinstance(llvm_lhs.type, ir.PointerType) diff --git a/tests/test_binary_op.py b/tests/test_binary_op.py index af962103..29111f3a 100644 --- a/tests/test_binary_op.py +++ b/tests/test_binary_op.py @@ -5,9 +5,9 @@ import pytest from irx import astx +from irx.astx import PrintExpr from irx.builders.base import Builder from irx.builders.llvmliteir import Builder as LLVMBuilder -from irx.system import PrintExpr from .conftest import check_result @@ -255,3 +255,63 @@ def test_binary_op_logical_and_or( module.block.append(main_fn) check_result("build", builder, module, expected_output=expect) + + +@pytest.mark.parametrize( + "a_val, b_val, op, true_label, false_label", + [ + (2.0, 3.0, "<=", "le_true", "le_false"), + (3.0, 2.0, ">=", "ge_true", "ge_false"), + (2.0, 2.0, "==", "eq_true", "eq_false"), + (1.0, 2.0, "!=", "ne_true", "ne_false"), + ], +) +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_binary_op_float_comparison( + builder_class: type[Builder], + a_val: float, + b_val: float, + op: str, + true_label: str, + false_label: str, +) -> None: + """ + title: Test Float32 comparison operators cover fcmp_ordered paths. + parameters: + builder_class: + type: type[Builder] + a_val: + type: float + b_val: + type: float + op: + type: str + true_label: + type: str + false_label: + type: str + """ + builder = builder_class() + module = builder.module() + + cond = astx.BinaryOp( + op_code=op, + lhs=astx.LiteralFloat32(a_val), + rhs=astx.LiteralFloat32(b_val), + ) + then_blk = astx.Block() + then_blk.append(PrintExpr(astx.LiteralUTF8String(true_label))) + else_blk = astx.Block() + else_blk.append(PrintExpr(astx.LiteralUTF8String(false_label))) + if_stmt = astx.IfStmt(condition=cond, then=then_blk, else_=else_blk) + + proto = astx.FunctionPrototype( + name="main", args=astx.Arguments(), return_type=astx.Int32() + ) + block = astx.Block() + block.append(if_stmt) + block.append(astx.FunctionReturn(astx.LiteralInt32(0))) + fn = astx.FunctionDef(prototype=proto, body=block) + module.block.append(fn) + + check_result("build", builder, module, expected_output=true_label) diff --git a/tests/test_vector.py b/tests/test_vector.py index 537d3f12..4f256e99 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -53,7 +53,9 @@ def _make_binop_visitor( def mock_visit(node: Any, *args: Any, **kwargs: Any) -> Any: """ - title: Mock visit. + title: >- + Intercept Identifier visits to inject pre-built IR values for LHS, + RHS, and FMA_RHS; delegate all other nodes to the real visitor. parameters: node: type: Any @@ -83,7 +85,6 @@ def _run_vector_binop( rhs_val: ir.Value, unsigned: bool | None = None, fma_rhs: ir.Value | None = None, - fast_math: bool = False, ) -> ir.Value: """ title: Drive a BinaryOp through the visitor and return the result. @@ -98,8 +99,6 @@ def _run_vector_binop( type: bool | None fma_rhs: type: ir.Value | None - fast_math: - type: bool returns: type: ir.Value """ @@ -109,8 +108,6 @@ def _run_vector_binop( ) if unsigned is not None: bin_op.unsigned = unsigned - if fast_math: - bin_op.fast_math = True if fma_rhs is not None: bin_op.fma = True bin_op.fma_rhs = astx.Identifier("FMA_RHS") @@ -138,7 +135,7 @@ def _run_vector_binop( def _arith_id(case: tuple[Any, ...]) -> str: """ - title: Arith id. + title: Build a readable pytest ID string for an arithmetic test case. parameters: case: type: tuple[Any, Ellipsis] @@ -245,7 +242,7 @@ def test_int_vector_division(unsigned: bool, want: str, reject: str) -> None: def _splat_id(case: tuple[Any, ...]) -> str: """ - title: Splat id. + title: Build a readable pytest ID string for a scalar-splat test case. parameters: case: type: tuple[Any, Ellipsis] @@ -351,7 +348,7 @@ def test_unsigned_scalar_splatted_to_vector_zero_extends() -> None: def _cross_id(case: tuple[Any, ...]) -> str: """ - title: Cross id. + title: Build a pytest ID for a cross-precision FP test case. parameters: case: type: tuple[Any, Ellipsis] @@ -470,69 +467,18 @@ def test_fma_missing_fma_rhs_raises() -> None: @pytest.mark.parametrize( - "op, fast_math, expect_fast", - [ - ("+", False, False), - ("+", True, True), - ("-", True, True), - ("*", True, True), - ("/", True, True), - ], - ids=["no_fast", "add_fast", "sub_fast", "mul_fast", "div_fast"], + "op, raises", + [("+", False), ("%", True)], + ids=["success", "failure"], ) -def test_vector_float_binop_uses_node_fast_math( - op: str, - fast_math: bool, - expect_fast: bool, -) -> None: +def test_fast_math_flag_always_cleared(op: str, raises: bool) -> None: """ title: >- - Float vector BinaryOp nodes honour the node fast_math attribute when they - go through visit(). + _fast_math_enabled is reset to False after the op regardless of whether + it succeeds or raises. parameters: op: type: str - fast_math: - type: bool - expect_fast: - type: bool - """ - builder = setup_builder() - vec_ty = ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4) - v = ir.Constant(vec_ty, [1.0] * VEC4) - result = _run_vector_binop(op, v, v, fast_math=fast_math) - - assert isinstance(result.type, ir.VectorType) - if expect_fast: - assert "fast" in result.flags - else: - assert "fast" not in result.flags - - -@pytest.mark.parametrize( - "op, initial_fast_math, raises", - [ - ("+", False, False), - ("+", True, False), - ("%", False, True), - ("%", True, True), - ], - ids=["success_false", "success_true", "failure_false", "failure_true"], -) -def test_fast_math_flag_restored_after_vector_binop( - op: str, - initial_fast_math: bool, - raises: bool, -) -> None: - """ - title: >- - Vector BinaryOp visits restore the prior fast-math state after success or - failure. - parameters: - op: - type: str - initial_fast_math: - type: bool raises: type: bool """ @@ -540,7 +486,6 @@ def test_fast_math_flag_restored_after_vector_binop( vec_ty = ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4) v = ir.Constant(vec_ty, [1.0] * VEC4) patched = _make_binop_visitor(v, v) - patched.set_fast_math(initial_fast_math) bin_op = astx.BinaryOp(op, astx.Identifier("LHS"), astx.Identifier("RHS")) bin_op.fast_math = True @@ -550,40 +495,11 @@ def test_fast_math_flag_restored_after_vector_binop( patched.visit(bin_op) else: patched.visit(bin_op) - patched.result_stack.pop() - - assert patched._fast_math_enabled is initial_fast_math - - -@pytest.mark.parametrize( - "initial_fast_math", - [False, True], - ids=["restore_false", "restore_true"], -) -def test_fast_math_flag_restored_after_vector_fma( - initial_fast_math: bool, -) -> None: - """ - title: Vector FMA visits restore the prior fast-math state. - parameters: - initial_fast_math: - type: bool - """ - builder = setup_builder() - vec_ty = ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4) - v = ir.Constant(vec_ty, [1.0] * VEC4) - patched = _make_binop_visitor(v, v, v) - patched.set_fast_math(initial_fast_math) - - bin_op = astx.BinaryOp("*", astx.Identifier("LHS"), astx.Identifier("RHS")) - bin_op.fast_math = True - bin_op.fma = True - bin_op.fma_rhs = astx.Identifier("FMA_RHS") - - patched.visit(bin_op) - patched.result_stack.pop() + result = patched.result_stack.pop() + assert "fadd" in str(result) + assert isinstance(result.type, ir.VectorType) - assert patched._fast_math_enabled is initial_fast_math + assert patched._fast_math_enabled is False @pytest.mark.parametrize( @@ -613,43 +529,31 @@ def test_vector_size_mismatch_raises( _run_vector_binop("+", v1, v2) -def test_vector_element_type_mismatch_raises() -> None: +def test_vector_element_type_promotion() -> None: """ - title: Mismatched vector element types raise an exception. + title: Mismatched vector element types are promoted to the wider type. """ builder = setup_builder() v1 = ir.Constant(ir.VectorType(builder._llvm.INT32_TYPE, VEC2), [1] * VEC2) v2 = ir.Constant(ir.VectorType(builder._llvm.INT64_TYPE, VEC2), [1] * VEC2) - with pytest.raises(Exception, match="Vector element type mismatch"): - _run_vector_binop("+", v1, v2) + result = _run_vector_binop("+", v1, v2) + assert isinstance(result.type, ir.VectorType) + assert result.type.element == builder._llvm.INT64_TYPE + assert result.type.count == VEC2 @pytest.mark.parametrize( "op, match", [ ("%", r"Vector binop .* not implemented"), - ("==", r"Vector binop .* not implemented"), - ("!=", r"Vector binop .* not implemented"), - ("<", r"Vector binop .* not implemented"), - ("<=", r"Vector binop .* not implemented"), - (">", r"Vector binop .* not implemented"), - (">=", r"Vector binop .* not implemented"), ], ids=[ "unsupported_%", - "cmp_eq", - "cmp_ne", - "cmp_lt", - "cmp_le", - "cmp_gt", - "cmp_ge", ], ) def test_unsupported_vector_op_raises(op: str, match: str) -> None: """ - title: >- - Unsupported and unimplemented comparison operators all raise an - exception. + title: Unsupported vector binary operators raise an exception. parameters: op: type: str diff --git a/tests/test_vector_numeric_unification.py b/tests/test_vector_numeric_unification.py new file mode 100644 index 00000000..51c8fed7 --- /dev/null +++ b/tests/test_vector_numeric_unification.py @@ -0,0 +1,392 @@ +""" +title: Tests for heterogeneous vector-vector numeric unification (#270). +""" + +from typing import Any + +import pytest + +from irx import astx +from irx.builders.llvmliteir import Builder, Visitor +from llvmlite import ir + +VEC4 = 4 +VEC2 = 2 + + +def setup_builder() -> Visitor: + """ + title: Return a visitor with a live IRBuilder positioned inside main(). + returns: + type: Visitor + """ + main_builder = Builder() + visitor = main_builder.translator + func_type = ir.FunctionType(visitor._llvm.INT32_TYPE, []) + fn = ir.Function(visitor._llvm.module, func_type, name="main") + bb = fn.append_basic_block("entry") + visitor._llvm.ir_builder = ir.IRBuilder(bb) + return visitor + + +def _make_binop_visitor( + lhs_val: ir.Value, + rhs_val: ir.Value, + fma_rhs: ir.Value | None = None, +) -> Visitor: + """ + title: >- + Return a fresh visitor patched to inject pre-built IR values for the LHS, + RHS, and FMA_RHS identifiers. + parameters: + lhs_val: + type: ir.Value + rhs_val: + type: ir.Value + fma_rhs: + type: ir.Value | None + returns: + type: Visitor + """ + builder = setup_builder() + original_visit = builder.visit + + def mock_visit(node: Any, *args: Any, **kwargs: Any) -> Any: + """ + title: >- + Intercept Identifier visits to inject pre-built IR values for LHS, + RHS, and FMA_RHS; delegate all other nodes to the real visitor. + parameters: + node: + type: Any + args: + type: Any + variadic: positional + kwargs: + type: Any + variadic: keyword + returns: + type: Any + """ + if isinstance(node, astx.Identifier): + mapping = {"LHS": lhs_val, "RHS": rhs_val, "FMA_RHS": fma_rhs} + if node.name in mapping: + builder.result_stack.append(mapping[node.name]) + return + return original_visit(node, *args, **kwargs) + + builder.visit = mock_visit # type: ignore[method-assign] + return builder + + +def _run_vector_binop( + op_code: str, + lhs_val: ir.Value, + rhs_val: ir.Value, + unsigned: bool | None = None, + fma_rhs: ir.Value | None = None, +) -> ir.Value: + """ + title: Drive a BinaryOp through the visitor and return the result. + parameters: + op_code: + type: str + lhs_val: + type: ir.Value + rhs_val: + type: ir.Value + unsigned: + type: bool | None + fma_rhs: + type: ir.Value | None + returns: + type: ir.Value + """ + builder = _make_binop_visitor(lhs_val, rhs_val, fma_rhs) + bin_op = astx.BinaryOp( + op_code, astx.Identifier("LHS"), astx.Identifier("RHS") + ) + if unsigned is not None: + bin_op.unsigned = unsigned + if fma_rhs is not None: + bin_op.fma = True + bin_op.fma_rhs = astx.Identifier("FMA_RHS") + builder.visit(bin_op) + return builder.result_stack.pop() + + +# --------------------------------------------------------------------------- +# Gap 1: Vector-vector element-type promotion +# --------------------------------------------------------------------------- + + +class TestVectorVectorPromotion: + """ + title: Tests for heterogeneous vector-vector element-type promotion. + """ + + def test_int16_plus_int32_vector(self) -> None: + """ + title: vector + vector promotes both to vector. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.INT16_TYPE, VEC4), [1] * VEC4 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC4), [2] * VEC4 + ) + result = _run_vector_binop("+", v1, v2) + assert isinstance(result.type, ir.VectorType) + assert result.type.element == builder._llvm.INT32_TYPE + assert result.type.count == VEC4 + + def test_float_plus_double_vector(self) -> None: + """ + title: vector + vector promotes both to vector. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4), [1.0] * VEC4 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.DOUBLE_TYPE, VEC4), [2.0] * VEC4 + ) + result = _run_vector_binop("+", v1, v2) + assert isinstance(result.type, ir.VectorType) + assert result.type.element == builder._llvm.DOUBLE_TYPE + assert result.type.count == VEC4 + + def test_unsigned_int_vector_widening_uses_zext(self) -> None: + """ + title: >- + Unsigned vector + vector uses zext (not sext) for widening. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.INT16_TYPE, VEC2), [1] * VEC2 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC2), [2] * VEC2 + ) + lhs, _rhs = builder._unify_numeric_operands(v1, v2, unsigned=True) + assert lhs.type.element == builder._llvm.INT32_TYPE + assert getattr(lhs, "opname", "") == "zext" + + def test_signed_int_vector_widening_uses_sext(self) -> None: + """ + title: Signed vector + vector uses sext for widening. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.INT16_TYPE, VEC2), [1] * VEC2 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC2), [2] * VEC2 + ) + lhs, _rhs = builder._unify_numeric_operands(v1, v2, unsigned=False) + assert lhs.type.element == builder._llvm.INT32_TYPE + assert getattr(lhs, "opname", "") == "sext" + + def test_unsigned_int_to_float_vector_uses_uitofp(self) -> None: + """ + title: >- + Unsigned vector + vector converts int to float via + uitofp. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC4), [1] * VEC4 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4), [2.0] * VEC4 + ) + lhs, _rhs = builder._unify_numeric_operands(v1, v2, unsigned=True) + assert lhs.type.element == builder._llvm.FLOAT_TYPE + assert getattr(lhs, "opname", "") == "uitofp" + + def test_signed_int_to_float_vector_uses_sitofp(self) -> None: + """ + title: >- + Signed vector + vector converts int to float via sitofp. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC4), [1] * VEC4 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4), [2.0] * VEC4 + ) + lhs, _rhs = builder._unify_numeric_operands(v1, v2, unsigned=False) + assert lhs.type.element == builder._llvm.FLOAT_TYPE + assert getattr(lhs, "opname", "") == "sitofp" + + def test_lane_count_mismatch_still_raises(self) -> None: + """ + title: >- + Vectors with different lane counts still raise even when element + types differ. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.INT16_TYPE, VEC4), [1] * VEC4 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC2), [2] * VEC2 + ) + with pytest.raises(Exception, match="Vector size mismatch"): + builder._unify_numeric_operands(v1, v2) + + def test_same_element_type_returns_unchanged(self) -> None: + """ + title: Vectors with identical element types are returned unchanged. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC4), [1] * VEC4 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.INT32_TYPE, VEC4), [2] * VEC4 + ) + lhs, rhs = builder._unify_numeric_operands(v1, v2) + assert lhs is v1 + assert rhs is v2 + + +# --------------------------------------------------------------------------- +# Gap 2: Vector compare operations +# --------------------------------------------------------------------------- + + +_VEC_CMP_OPS = ["<", ">", "<=", ">=", "==", "!="] + + +class TestVectorCompare: + """ + title: Tests for vector compare operations via the BinaryOp visitor. + """ + + @pytest.mark.parametrize("op", _VEC_CMP_OPS, ids=_VEC_CMP_OPS) + def test_float_vector_compare(self, op: str) -> None: + """ + title: Float vector compares emit fcmp_ordered. + parameters: + op: + type: str + """ + builder = setup_builder() + vec_ty = ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4) + v1 = ir.Constant(vec_ty, [1.0] * VEC4) + v2 = ir.Constant(vec_ty, [2.0] * VEC4) + result = _run_vector_binop(op, v1, v2) + assert isinstance(result.type, ir.VectorType) + assert result.type.count == VEC4 + assert "fcmp" in str(result) + + @pytest.mark.parametrize("op", _VEC_CMP_OPS, ids=_VEC_CMP_OPS) + def test_int_vector_compare_signed(self, op: str) -> None: + """ + title: Signed int vector compares emit icmp. + parameters: + op: + type: str + """ + builder = setup_builder() + vec_ty = ir.VectorType(builder._llvm.INT32_TYPE, VEC4) + v1 = ir.Constant(vec_ty, [1] * VEC4) + v2 = ir.Constant(vec_ty, [2] * VEC4) + result = _run_vector_binop(op, v1, v2) + assert isinstance(result.type, ir.VectorType) + assert result.type.count == VEC4 + assert "icmp" in str(result) + + @pytest.mark.parametrize("op", _VEC_CMP_OPS, ids=_VEC_CMP_OPS) + def test_int_vector_compare_unsigned(self, op: str) -> None: + """ + title: Unsigned int vector compares emit icmp with unsigned predicates. + parameters: + op: + type: str + """ + builder = setup_builder() + vec_ty = ir.VectorType(builder._llvm.INT32_TYPE, VEC4) + v1 = ir.Constant(vec_ty, [1] * VEC4) + v2 = ir.Constant(vec_ty, [2] * VEC4) + result = _run_vector_binop(op, v1, v2, unsigned=True) + assert isinstance(result.type, ir.VectorType) + assert result.type.count == VEC4 + assert "icmp" in str(result) + + def test_heterogeneous_vector_compare(self) -> None: + """ + title: >- + vector < vector promotes to vector then + compares. + """ + builder = setup_builder() + v1 = ir.Constant( + ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4), [1.0] * VEC4 + ) + v2 = ir.Constant( + ir.VectorType(builder._llvm.DOUBLE_TYPE, VEC4), [2.0] * VEC4 + ) + result = _run_vector_binop("<", v1, v2) + assert isinstance(result.type, ir.VectorType) + assert result.type.count == VEC4 + assert "fcmp" in str(result) + + def test_scalar_vector_compare(self) -> None: + """ + title: Scalar < vector promotes and splats the scalar, then compares. + """ + builder = setup_builder() + vec_ty = ir.VectorType(builder._llvm.INT32_TYPE, VEC4) + v = ir.Constant(vec_ty, [1] * VEC4) + s = ir.Constant(builder._llvm.INT32_TYPE, 2) + result = _run_vector_binop("<", s, v) + assert isinstance(result.type, ir.VectorType) + assert result.type.count == VEC4 + assert "icmp" in str(result) + + +# --------------------------------------------------------------------------- +# Gap 3: FMA unification +# --------------------------------------------------------------------------- + + +class TestFMAUnification: + """ + title: Tests for FMA with heterogeneous operand types. + """ + + def test_fma_float_float_double_unifies(self) -> None: + """ + title: >- + FMA with vector * vector + vector promotes all + to vector. + """ + builder = setup_builder() + vf = ir.Constant( + ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4), [1.0] * VEC4 + ) + vd = ir.Constant( + ir.VectorType(builder._llvm.DOUBLE_TYPE, VEC4), [2.0] * VEC4 + ) + result = _run_vector_binop("*", vf, vf, fma_rhs=vd) + assert isinstance(result.type, ir.VectorType) + assert result.type.element == builder._llvm.DOUBLE_TYPE + assert result.type.count == VEC4 + + def test_fma_same_type_still_works(self) -> None: + """ + title: >- + FMA with matching types still works after replacing the hard error + with unification. + """ + builder = setup_builder() + vec_ty = ir.VectorType(builder._llvm.FLOAT_TYPE, VEC4) + v = ir.Constant(vec_ty, [2.0] * VEC4) + result = _run_vector_binop("*", v, v, fma_rhs=v) + assert isinstance(result.type, ir.VectorType) + assert result.type.element == builder._llvm.FLOAT_TYPE + assert result.type.count == VEC4 diff --git a/tests/test_while.py b/tests/test_while.py index 2e015551..10ef2387 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -5,6 +5,7 @@ import pytest from irx import astx +from irx.astx import PrintExpr from irx.builders.base import Builder from irx.builders.llvmliteir import Builder as LLVMBuilder from llvmlite import binding as llvm @@ -19,6 +20,8 @@ (astx.Int16, astx.LiteralInt16), (astx.Int8, astx.LiteralInt8), (astx.Int64, astx.LiteralInt64), + (astx.Float32, astx.LiteralFloat32), + (astx.Float64, astx.LiteralFloat64), ], ) @pytest.mark.parametrize( @@ -69,8 +72,15 @@ def test_while_expr( var_a = astx.Identifier("a") cond = astx.BinaryOp(op_code="<", lhs=var_a, rhs=literal_type(5)) - # Update: ++a - update = astx.UnaryOp(op_code="++", operand=var_a) + # Update: a = a + 1 (works for int and float; ++ only works for int) + update = astx.VariableAssignment( + name="a", + value=astx.BinaryOp( + op_code="+", + lhs=astx.Identifier("a"), + rhs=literal_type(1), + ), + ) # Body body = astx.Block() @@ -276,3 +286,56 @@ def test_while_false_condition( module.block.append(fn_main) check_result(action, builder, module, expected_file) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_while_float_condition( + builder_class: type[Builder], +) -> None: + """ + title: Test WhileStmt with a Float32 condition covers fcmp_ordered path. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + + # float a = 3.0 + init_var = astx.InlineVariableDeclaration( + "a", + type_=astx.Float32(), + value=astx.LiteralFloat32(3.0), + mutability=astx.MutabilityKind.mutable, + ) + + # condition: a (evaluates to float, hits fcmp_ordered 0.0) + var_a = astx.Identifier("a") + cond = var_a + + # body: a = a - 1.0 + dec = astx.VariableAssignment( + name="a", + value=astx.BinaryOp( + op_code="-", lhs=var_a, rhs=astx.LiteralFloat32(1.0) + ), + ) + body = astx.Block() + body.append(dec) + + while_expr = astx.WhileStmt(condition=cond, body=body) + + # Print "done" after loop to prove execution completed. + proto = astx.FunctionPrototype( + name="main", args=astx.Arguments(), return_type=astx.Int32() + ) + fn_block = astx.Block() + fn_block.append(init_var) + fn_block.append(while_expr) + fn_block.append(PrintExpr(astx.LiteralUTF8String("done"))) + fn_block.append(astx.FunctionReturn(astx.LiteralInt32(0))) + + fn_main = astx.FunctionDef(prototype=proto, body=fn_block) + module = builder.module() + module.block.append(fn_main) + + check_result("build", builder, module, expected_output="done")