From d3617319cc4a20969b57fc09b62a00d1d1b2abc4 Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Wed, 8 Apr 2026 15:15:56 +0000 Subject: [PATCH 1/4] feat: Complete the scalar numeric foundation --- README.md | 17 +- docs/semantic-contract.md | 40 ++++ src/irx/analysis/normalization.py | 15 +- src/irx/analysis/types.py | 241 ++++++++++++++++++++--- src/irx/analysis/validation.py | 25 +-- src/irx/builder/core.py | 240 +++++++++++++++++++++- src/irx/builder/lowering/binary_ops.py | 18 +- src/irx/builder/lowering/functions.py | 67 +++++-- src/irx/builder/lowering/system.py | 87 +++----- src/irx/builder/lowering/variables.py | 15 ++ src/irx/builder/protocols.py | 86 +++++++- tests/analysis/test_numeric_semantics.py | 100 ++++++++++ tests/test_numeric_foundation.py | 182 +++++++++++++++++ 13 files changed, 990 insertions(+), 143 deletions(-) create mode 100644 tests/analysis/test_numeric_semantics.py create mode 100644 tests/test_numeric_foundation.py diff --git a/README.md b/README.md index 01eae4e9..d61ccac1 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,8 @@ via `clang`. - **Literals:** `LiteralInt16`, `LiteralInt32`, `LiteralString` - **Variables:** `Variable`, `VariableDeclaration`, `InlineVariableDeclaration` - - **Ops:** `UnaryOp` (`++`, `--`), `BinaryOp` (`+ - * / < >`) with simple type - promotion + - **Ops:** `UnaryOp` (`++`, `--`), `BinaryOp` (`+ - * / < >`) with documented + scalar numeric promotion and cast rules - **Flow:** `IfStmt`, `ForCountLoopStmt`, `ForRangeLoopStmt` - **Functions:** `FunctionPrototype`, `Function`, `FunctionReturn`, `FunctionCall` @@ -162,6 +162,19 @@ The current MVP is intentionally narrow: primitive `int32` arrays, lifecycle operations, inspection, and C Data roundtrip support. No full Arrow container semantics are encoded directly in LLVM IR. +## Scalar Numeric Semantics + +IRx now treats scalar numerics as a stable substrate instead of an ad hoc +"simple promotion" layer: + +- one canonical promotion table for signed integers, unsigned integers, and + floats +- one canonical implicit-promotion vs explicit-cast policy +- comparisons always resolve to `Boolean` / LLVM `i1` + +The full contract lives in +[docs/semantic-contract.md](https://github.com/arxlang/irx/blob/main/docs/semantic-contract.md). + ## Testing ```bash diff --git a/docs/semantic-contract.md b/docs/semantic-contract.md index e569a510..f9a30ea9 100644 --- a/docs/semantic-contract.md +++ b/docs/semantic-contract.md @@ -49,6 +49,46 @@ For multi-module compilation, IRx also guarantees the following Lowering should consume this semantic metadata instead of re-deriving meaning from raw syntax. +## Scalar Numeric Foundation + +Binary scalar numerics use one canonical promotion table: + +| Operand mix | Promoted operand type | +| --------------------- | ---------------------------------------------------------------------------------------------------- | +| `float + float` | wider float | +| `float + integer` | float widened to cover the integer width floor (`16`, `32`, or `64` bits), capped at `Float64` | +| `signed + signed` | wider signed integer | +| `unsigned + unsigned` | wider unsigned integer | +| `signed + unsigned` | wider signed integer when the signed operand is strictly wider; otherwise the wider unsigned integer | + +Comparison operators (`<`, `>`, `<=`, `>=`, `==`, `!=`) promote their operands +with the same table and always return `Boolean` semantically and `i1` in LLVM +IR. + +### Canonical Cast Policy + +Implicit promotions in variable initializers, assignments, call arguments, and +returns are intentionally narrower than explicit casts: + +- same-type assignment is always allowed +- signed integers may widen to wider signed integers +- unsigned integers may widen to wider unsigned integers +- unsigned integers may widen to strictly wider signed integers +- integers may promote to floats when the target float width meets the same + `16`/`32`/`64` floor used by the numeric-promotion table +- floats may widen to wider floats +- implicit sign-changing integer casts to unsigned targets are rejected +- implicit narrowing casts are rejected +- implicit float-to-integer and numeric-to-boolean casts are rejected + +Explicit `Cast(...)` expressions allow the full scalar conversions: + +- numeric-to-numeric casts +- boolean-to-numeric casts using `0` and `1` +- numeric-to-boolean casts using `!= 0` or `!= 0.0` +- string-to-string casts +- numeric/boolean-to-string casts through runtime formatting + ## Error Boundaries - Semantic errors: invalid programs, unsupported semantic input, and import diff --git a/src/irx/analysis/normalization.py b/src/irx/analysis/normalization.py index 731492aa..9ecb3d2f 100644 --- a/src/irx/analysis/normalization.py +++ b/src/irx/analysis/normalization.py @@ -9,7 +9,11 @@ from irx import astx from irx.analysis.resolved_nodes import ResolvedOperator, SemanticFlags -from irx.analysis.types import is_unsigned_type +from irx.analysis.types import ( + common_numeric_type, + is_integer_type, + is_unsigned_type, +) def normalize_flags( @@ -31,10 +35,17 @@ def normalize_flags( type: SemanticFlags """ explicit_unsigned = getattr(node, "unsigned", None) + promoted_type = common_numeric_type(lhs_type, rhs_type) unsigned = ( bool(explicit_unsigned) if explicit_unsigned is not None - else is_unsigned_type(lhs_type) or is_unsigned_type(rhs_type) + else ( + is_integer_type(promoted_type) and is_unsigned_type(promoted_type) + ) + or ( + promoted_type is None + and (is_unsigned_type(lhs_type) or is_unsigned_type(rhs_type)) + ) ) return SemanticFlags( unsigned=unsigned, diff --git a/src/irx/analysis/types.py b/src/irx/analysis/types.py index 3f9cda23..4c743d19 100644 --- a/src/irx/analysis/types.py +++ b/src/irx/analysis/types.py @@ -36,6 +36,24 @@ astx.Float32: BIT_WIDTH_32, astx.Float64: BIT_WIDTH_64, } +_SIGNED_INTEGERS_BY_WIDTH: dict[int, type[astx.DataType]] = { + BIT_WIDTH_8: astx.Int8, + BIT_WIDTH_16: astx.Int16, + BIT_WIDTH_32: astx.Int32, + BIT_WIDTH_64: astx.Int64, +} +_UNSIGNED_INTEGERS_BY_WIDTH: dict[int, type[astx.DataType]] = { + BIT_WIDTH_8: astx.UInt8, + BIT_WIDTH_16: astx.UInt16, + BIT_WIDTH_32: astx.UInt32, + BIT_WIDTH_64: astx.UInt64, + BIT_WIDTH_128: astx.UInt128, +} +_FLOATS_BY_WIDTH: dict[int, type[astx.DataType]] = { + BIT_WIDTH_16: astx.Float16, + BIT_WIDTH_32: astx.Float32, + BIT_WIDTH_64: astx.Float64, +} @public @@ -200,6 +218,111 @@ def bit_width(type_: astx.DataType | None) -> int: return _BIT_WIDTHS.get(type(type_), 0) +def _type_for_width( + width: int, + table: dict[int, type[astx.DataType]], +) -> astx.DataType | None: + """ + title: Instantiate a type family member for one width. + parameters: + width: + type: int + table: + type: dict[int, type[astx.DataType]] + returns: + type: astx.DataType | None + """ + type_cls = table.get(width) + if type_cls is None: + return None + return type_cls() + + +def _min_float_width_for_integer_width(width: int) -> int: + """ + title: Return the float width used when integers promote with floats. + parameters: + width: + type: int + returns: + type: int + """ + if width <= BIT_WIDTH_16: + return BIT_WIDTH_16 + if width <= BIT_WIDTH_32: + return BIT_WIDTH_32 + return BIT_WIDTH_64 + + +def _common_integer_type( + lhs: astx.DataType, + rhs: astx.DataType, +) -> astx.DataType | None: + """ + title: Return the canonical promoted type for two integers. + parameters: + lhs: + type: astx.DataType + rhs: + type: astx.DataType + returns: + type: astx.DataType | None + """ + lhs_width = bit_width(lhs) + rhs_width = bit_width(rhs) + + if is_signed_integer_type(lhs) and is_signed_integer_type(rhs): + return _type_for_width( + max(lhs_width, rhs_width), + _SIGNED_INTEGERS_BY_WIDTH, + ) + + if is_unsigned_type(lhs) and is_unsigned_type(rhs): + return _type_for_width( + max(lhs_width, rhs_width), + _UNSIGNED_INTEGERS_BY_WIDTH, + ) + + signed_width = lhs_width if is_signed_integer_type(lhs) else rhs_width + unsigned_width = lhs_width if is_unsigned_type(lhs) else rhs_width + + if signed_width > unsigned_width: + return _type_for_width(signed_width, _SIGNED_INTEGERS_BY_WIDTH) + return _type_for_width( + max(lhs_width, rhs_width), + _UNSIGNED_INTEGERS_BY_WIDTH, + ) + + +def _common_float_type( + lhs: astx.DataType, + rhs: astx.DataType, +) -> astx.DataType | None: + """ + title: Return the canonical promoted type for operands with floats. + parameters: + lhs: + type: astx.DataType + rhs: + type: astx.DataType + returns: + type: astx.DataType | None + """ + float_width = max( + bit_width(type_) for type_ in (lhs, rhs) if is_float_type(type_) + ) + integer_width = max( + (bit_width(type_) for type_ in (lhs, rhs) if is_integer_type(type_)), + default=0, + ) + target_width = max( + float_width, + _min_float_width_for_integer_width(integer_width), + ) + target_width = min(target_width, BIT_WIDTH_64) + return _type_for_width(target_width, _FLOATS_BY_WIDTH) + + @public def common_numeric_type( lhs: astx.DataType | None, @@ -221,33 +344,90 @@ def common_numeric_type( return None if is_float_type(lhs) or is_float_type(rhs): - widest = max(bit_width(lhs), bit_width(rhs)) - if widest <= BIT_WIDTH_16: - return astx.Float16() - if widest <= BIT_WIDTH_32: - return astx.Float32() - return astx.Float64() - - width = max(bit_width(lhs), bit_width(rhs)) - use_unsigned = is_unsigned_type(lhs) or is_unsigned_type(rhs) - if use_unsigned: - if width <= BIT_WIDTH_8: - return astx.UInt8() - if width <= BIT_WIDTH_16: - return astx.UInt16() - if width <= BIT_WIDTH_32: - return astx.UInt32() - if width <= BIT_WIDTH_64: - return astx.UInt64() - return astx.UInt128() - - if width <= BIT_WIDTH_8: - return astx.Int8() - if width <= BIT_WIDTH_16: - return astx.Int16() - if width <= BIT_WIDTH_32: - return astx.Int32() - return astx.Int64() + return _common_float_type(lhs, rhs) + return _common_integer_type(lhs, rhs) + + +def _is_safe_integer_assignment( + target: astx.DataType, + value: astx.DataType, +) -> bool: + """ + title: Return whether one integer can implicitly promote into another. + parameters: + target: + type: astx.DataType + value: + type: astx.DataType + returns: + type: bool + """ + target_width = bit_width(target) + value_width = bit_width(value) + + if is_signed_integer_type(target) and is_signed_integer_type(value): + return target_width >= value_width + if is_unsigned_type(target) and is_unsigned_type(value): + return target_width >= value_width + if is_signed_integer_type(target) and is_unsigned_type(value): + return target_width > value_width + return False + + +def _is_safe_float_assignment( + target: astx.DataType, + value: astx.DataType, +) -> bool: + """ + title: Return whether a value can implicitly promote to a float target. + parameters: + target: + type: astx.DataType + value: + type: astx.DataType + returns: + type: bool + """ + target_width = bit_width(target) + if is_float_type(value): + return target_width >= bit_width(value) + if is_integer_type(value): + return target_width >= _min_float_width_for_integer_width( + bit_width(value) + ) + return False + + +@public +def is_explicitly_castable( + source: astx.DataType | None, + target: astx.DataType | None, +) -> bool: + """ + title: Return whether an explicit Cast expression is allowed. + parameters: + source: + type: astx.DataType | None + target: + type: astx.DataType | None + returns: + type: bool + """ + if source is None or target is None: + return True + if is_assignable(target, source): + return True + if (is_numeric_type(source) or is_boolean_type(source)) and ( + is_numeric_type(target) or is_boolean_type(target) + ): + return True + if isinstance(target, (astx.String, astx.UTF8String)): + return ( + is_string_type(source) + or is_numeric_type(source) + or (is_boolean_type(source)) + ) + return False @public @@ -269,9 +449,10 @@ def is_assignable( return True if same_type(target, value): return True - if is_numeric_type(target) and is_numeric_type(value): - common = common_numeric_type(target, value) - return common is not None and same_type(target, common) + if is_integer_type(target) and is_integer_type(value): + return _is_safe_integer_assignment(target, value) + if is_float_type(target) and is_numeric_type(value): + return _is_safe_float_assignment(target, value) if is_string_type(target) and is_string_type(value): return True if is_none_type(target) and is_none_type(value): diff --git a/src/irx/analysis/validation.py b/src/irx/analysis/validation.py index ec9e937a..362c15bd 100644 --- a/src/irx/analysis/validation.py +++ b/src/irx/analysis/validation.py @@ -12,7 +12,10 @@ from irx import astx from irx.analysis.diagnostics import DiagnosticBag from irx.analysis.resolved_nodes import SemanticFunction -from irx.analysis.types import is_assignable, is_boolean_type, is_numeric_type +from irx.analysis.types import ( + is_assignable, + is_explicitly_castable, +) TIME_PARTS_HOUR_MINUTE = 2 TIME_PARTS_HOUR_MINUTE_SECOND = 3 @@ -103,13 +106,7 @@ def validate_cast( """ if source_type is None or target_type is None: return - if is_assignable(target_type, source_type): - return - if _is_numeric_cast_type(source_type) and _is_numeric_cast_type( - target_type - ): - return - if isinstance(target_type, (astx.String, astx.UTF8String)): + if is_explicitly_castable(source_type, target_type): return diagnostics.add( f"Unsupported cast from {source_type} to {target_type}", @@ -117,18 +114,6 @@ def validate_cast( ) -def _is_numeric_cast_type(type_: astx.DataType | None) -> bool: - """ - title: Is numeric cast type. - parameters: - type_: - type: astx.DataType | None - returns: - type: bool - """ - return is_numeric_type(type_) or is_boolean_type(type_) - - def validate_literal_time(value: str) -> time: """ title: Validate an astx time literal. diff --git a/src/irx/builder/core.py b/src/irx/builder/core.py index 29df4972..6f384fcd 100644 --- a/src/irx/builder/core.py +++ b/src/irx/builder/core.py @@ -27,6 +27,14 @@ mangle_function_name, mangle_struct_name, ) +from irx.analysis.types import ( + bit_width, + common_numeric_type, + is_boolean_type, + is_float_type, + is_integer_type, + is_unsigned_type, +) from irx.builder.base import BuilderVisitor from irx.builder.runtime import safe_pop from irx.builder.runtime.registry import ( @@ -266,6 +274,7 @@ class VisitorCore(BuilderVisitor): _emitted_function_bodies: set[str] entry_function_symbol_id: str | None _fast_math_enabled: bool + _current_function_return_type: astx.DataType | None target: llvm.TargetRef target_machine: llvm.TargetMachine @@ -292,6 +301,7 @@ def __init__( self._emitted_function_bodies = set() self.entry_function_symbol_id = None self._fast_math_enabled = False + self._current_function_return_type = None self.initialize() self.target = llvm.Target.from_default_triple() @@ -748,6 +758,183 @@ def _is_numeric_value(self, value: ir.Value) -> bool: return isinstance(elem_ty, ir.IntType) or is_fp_type(elem_ty) return isinstance(value.type, ir.IntType) or is_fp_type(value.type) + def _resolved_ast_type( + self, node: astx.AST | None + ) -> astx.DataType | None: + """ + title: Resolved ast type. + parameters: + node: + type: astx.AST | None + returns: + type: astx.DataType | None + """ + if node is None: + return None + semantic = getattr(node, "semantic", None) + resolved_type = getattr(semantic, "resolved_type", None) + if isinstance(resolved_type, astx.DataType): + return resolved_type + fallback = getattr(node, "type_", None) + return fallback if isinstance(fallback, astx.DataType) else None + + def _llvm_type_for_ast_type( + self, type_: astx.DataType | None + ) -> ir.Type | None: + """ + title: Llvm type for ast type. + parameters: + type_: + type: astx.DataType | None + returns: + type: ir.Type | None + """ + if type_ is None: + return None + type_name = type_.__class__.__name__.lower() + return self._llvm.get_data_type(type_name) + + def _bool_value_from_numeric( + self, + value: ir.Value, + source_type: astx.DataType | None, + *, + name: str, + ) -> ir.Value: + """ + title: Convert one numeric value to boolean truthiness. + parameters: + value: + type: ir.Value + source_type: + type: astx.DataType | None + name: + type: str + returns: + type: ir.Value + """ + if is_boolean_type(source_type): + return value + if is_float_type(source_type): + zero = ir.Constant(value.type, 0.0) + return self._llvm.ir_builder.fcmp_ordered("!=", value, zero, name) + if is_integer_type(source_type): + zero = ir.Constant(value.type, 0) + return self._llvm.ir_builder.icmp_unsigned("!=", value, zero, name) + raise Exception(f"Unsupported boolean conversion from {source_type!r}") + + def _cast_ast_value( + self, + value: ir.Value, + *, + source_type: astx.DataType | None, + target_type: astx.DataType | None, + ) -> ir.Value: + """ + title: Cast one lowered value using semantic scalar types. + parameters: + value: + type: ir.Value + source_type: + type: astx.DataType | None + target_type: + type: astx.DataType | None + returns: + type: ir.Value + """ + if source_type is None or target_type is None: + return value + + target_llvm_type = self._llvm_type_for_ast_type(target_type) + if target_llvm_type is None: + return value + if value.type == target_llvm_type: + return value + + builder = self._llvm.ir_builder + + if is_boolean_type(target_type): + return self._bool_value_from_numeric( + value, + source_type, + name="boolcast", + ) + + if is_boolean_type(source_type): + if is_integer_type(target_type): + return builder.zext(value, target_llvm_type, "bool_zext") + if is_float_type(target_type): + return builder.uitofp(value, target_llvm_type, "bool_to_fp") + + if is_integer_type(source_type) and is_integer_type(target_type): + source_width = bit_width(source_type) + target_width = bit_width(target_type) + + if source_width == target_width: + return value + if source_width < target_width: + if is_unsigned_type(source_type): + return builder.zext(value, target_llvm_type, "zext") + return builder.sext(value, target_llvm_type, "sext") + return builder.trunc(value, target_llvm_type, "trunc") + + if is_integer_type(source_type) and is_float_type(target_type): + if is_unsigned_type(source_type): + return builder.uitofp(value, target_llvm_type, "uitofp") + return builder.sitofp(value, target_llvm_type, "sitofp") + + if is_float_type(source_type) and is_integer_type(target_type): + if is_unsigned_type(target_type): + return builder.fptoui(value, target_llvm_type, "fptoui") + return builder.fptosi(value, target_llvm_type, "fptosi") + + if is_float_type(source_type) and is_float_type(target_type): + if bit_width(source_type) < bit_width(target_type): + return builder.fpext(value, target_llvm_type, "fpext") + return builder.fptrunc(value, target_llvm_type, "fptrunc") + + raise Exception( + f"Unsupported scalar cast from {source_type!r} to {target_type!r}" + ) + + def _coerce_numeric_operands_for_types( + self, + lhs: ir.Value, + rhs: ir.Value, + *, + lhs_type: astx.DataType | None, + rhs_type: astx.DataType | None, + ) -> tuple[ir.Value, ir.Value]: + """ + title: Coerce numeric operands from semantic types. + parameters: + lhs: + type: ir.Value + rhs: + type: ir.Value + lhs_type: + type: astx.DataType | None + rhs_type: + type: astx.DataType | None + returns: + type: tuple[ir.Value, ir.Value] + """ + target_type = common_numeric_type(lhs_type, rhs_type) + if target_type is None: + return lhs, rhs + return ( + self._cast_ast_value( + lhs, + source_type=lhs_type, + target_type=target_type, + ), + self._cast_ast_value( + rhs, + source_type=rhs_type, + target_type=target_type, + ), + ) + def _unify_numeric_operands( self, lhs: ir.Value, @@ -793,12 +980,31 @@ def _unify_numeric_operands( lhs_base_ty = lhs.type rhs_base_ty = rhs.type if is_fp_type(lhs_base_ty) or is_fp_type(rhs_base_ty): - candidates = [ + float_candidates = [ type_ for type_ in (lhs_base_ty, rhs_base_ty) if is_fp_type(type_) ] - target_scalar_ty = self._select_float_type(candidates) + integer_width = max( + ( + getattr(type_, "width", 0) + for type_ in (lhs_base_ty, rhs_base_ty) + if is_int_type(type_) + ), + default=0, + ) + float_width = max( + ( + self._float_bit_width(type_) + for type_ in float_candidates + ), + default=0, + ) + target_width = max( + float_width, + self._min_float_width_for_integer_bits(integer_width), + ) + target_scalar_ty = self._float_type_from_width(target_width) else: lhs_width = getattr(lhs_base_ty, "width", 0) rhs_width = getattr(rhs_base_ty, "width", 0) @@ -820,6 +1026,21 @@ def _unify_numeric_operands( return lhs, rhs + def _min_float_width_for_integer_bits(self, width: int) -> int: + """ + title: Minimum float width for integer bits. + parameters: + width: + type: int + returns: + type: int + """ + if width <= FLOAT16_BITS: + return FLOAT16_BITS + if width <= FLOAT32_BITS: + return FLOAT32_BITS + return FLOAT64_BITS + def _select_float_type(self, candidates: list[ir.Type]) -> ir.Type: """ title: Select float type. @@ -1735,13 +1956,18 @@ def _handle_string_comparison( raise Exception(f"String comparison operator {op} not implemented") def _normalize_int_for_printf( - self, value: ir.Value + self, + value: ir.Value, + *, + unsigned: bool = False, ) -> tuple[ir.Value, str]: """ title: Normalize int for printf. parameters: value: type: ir.Value + unsigned: + type: bool returns: type: tuple[ir.Value, str] """ @@ -1750,12 +1976,14 @@ def _normalize_int_for_printf( raise Exception("Expected integer value") width = value.type.width if width < int64_width: - if width == 1: + if width == 1 or unsigned: arg = self._llvm.ir_builder.zext(value, self._llvm.INT64_TYPE) - else: - arg = self._llvm.ir_builder.sext(value, self._llvm.INT64_TYPE) + return arg, "%llu" + arg = self._llvm.ir_builder.sext(value, self._llvm.INT64_TYPE) return arg, "%lld" if width == int64_width: + if unsigned: + return value, "%llu" return value, "%lld" raise Exception( "Casting integers wider than 64 bits to string is not supported" diff --git a/src/irx/builder/lowering/binary_ops.py b/src/irx/builder/lowering/binary_ops.py index 392b37af..9442a4d3 100644 --- a/src/irx/builder/lowering/binary_ops.py +++ b/src/irx/builder/lowering/binary_ops.py @@ -88,16 +88,25 @@ def _load_binary_operands( raise Exception("codegen: Invalid lhs/rhs") unsigned = uses_unsigned_semantics(node) + lhs_type = self._resolved_ast_type(node.lhs) + rhs_type = self._resolved_ast_type(node.rhs) if ( unify_numeric and self._is_numeric_value(llvm_lhs) and self._is_numeric_value(llvm_rhs) ): - llvm_lhs, llvm_rhs = self._unify_numeric_operands( + llvm_lhs, llvm_rhs = self._coerce_numeric_operands_for_types( llvm_lhs, llvm_rhs, - unsigned=unsigned, + lhs_type=lhs_type, + rhs_type=rhs_type, ) + if llvm_lhs.type != llvm_rhs.type: + llvm_lhs, llvm_rhs = self._unify_numeric_operands( + llvm_lhs, + llvm_rhs, + unsigned=unsigned, + ) return llvm_lhs, llvm_rhs, unsigned @@ -365,6 +374,11 @@ def visit(self, node: AssignmentBinOp) -> None: llvm_rhs = safe_pop(self.result_stack) if llvm_rhs is None: raise Exception("codegen: Invalid rhs expression.") + llvm_rhs = self._cast_ast_value( + llvm_rhs, + source_type=self._resolved_ast_type(node.rhs), + target_type=self._resolved_ast_type(node), + ) llvm_lhs = self.named_values.get(lhs_key) if not llvm_lhs: diff --git a/src/irx/builder/lowering/functions.py b/src/irx/builder/lowering/functions.py index 5c81a5b8..64b465f3 100644 --- a/src/irx/builder/lowering/functions.py +++ b/src/irx/builder/lowering/functions.py @@ -36,11 +36,26 @@ def visit(self, node: astx.FunctionCall) -> None: raise Exception("codegen: Incorrect # arguments passed.") llvm_args = [] - for arg in node.args: + resolved_function = getattr( + getattr(node, "semantic", None), + "resolved_function", + None, + ) + param_types = ( + [param.type_ for param in resolved_function.args] + if resolved_function is not None + else [None] * len(node.args) + ) + for arg, param_type in zip(node.args, param_types): self.visit_child(arg) llvm_arg = safe_pop(self.result_stack) if llvm_arg is None: raise Exception("codegen: Invalid callee argument.") + llvm_arg = self._cast_ast_value( + llvm_arg, + source_type=self._resolved_ast_type(arg), + target_type=param_type, + ) llvm_args.append(llvm_arg) result = self._llvm.ir_builder.call(callee_f, llvm_args, "calltmp") @@ -66,26 +81,33 @@ def visit(self, node: astx.FunctionDef) -> None: basic_block = fn.append_basic_block("entry") self._llvm.ir_builder = ir.IRBuilder(basic_block) - - for idx, llvm_arg in enumerate(fn.args): - arg_ast = proto.args.nodes[idx] - type_str = arg_ast.type_.__class__.__name__.lower() - arg_type = self._llvm.get_data_type(type_str) - symbol_key = semantic_symbol_key(arg_ast, llvm_arg.name) - alloca = self._llvm.ir_builder.alloca(arg_type, name=llvm_arg.name) - self._llvm.ir_builder.store(llvm_arg, alloca) - self.named_values[symbol_key] = alloca - - self.visit_child(node.body) - if not self._llvm.ir_builder.block.is_terminated: - return_type = fn.function_type.return_type - if isinstance(return_type, ir.VoidType): - self._llvm.ir_builder.ret_void() - else: - raise SyntaxError( - f"Function '{proto.name}' with return type " - f"'{return_type}' is missing a return statement" + previous_return_type = self._current_function_return_type + self._current_function_return_type = proto.return_type + + try: + for idx, llvm_arg in enumerate(fn.args): + arg_ast = proto.args.nodes[idx] + type_str = arg_ast.type_.__class__.__name__.lower() + arg_type = self._llvm.get_data_type(type_str) + symbol_key = semantic_symbol_key(arg_ast, llvm_arg.name) + alloca = self._llvm.ir_builder.alloca( + arg_type, name=llvm_arg.name ) + self._llvm.ir_builder.store(llvm_arg, alloca) + self.named_values[symbol_key] = alloca + + self.visit_child(node.body) + if not self._llvm.ir_builder.block.is_terminated: + return_type = fn.function_type.return_type + if isinstance(return_type, ir.VoidType): + self._llvm.ir_builder.ret_void() + else: + raise SyntaxError( + f"Function '{proto.name}' with return type " + f"'{return_type}' is missing a return statement" + ) + finally: + self._current_function_return_type = previous_return_type self._emitted_function_bodies.add(function_key) self.result_stack.append(fn) @@ -138,6 +160,11 @@ def visit(self, node: astx.FunctionReturn) -> None: retval = None if retval is not None: + retval = self._cast_ast_value( + retval, + source_type=self._resolved_ast_type(node.value), + target_type=self._current_function_return_type, + ) fn_return_type = ( self._llvm.ir_builder.function.function_type.return_type ) diff --git a/src/irx/builder/lowering/system.py b/src/irx/builder/lowering/system.py index 4f84d159..bb196811 100644 --- a/src/irx/builder/lowering/system.py +++ b/src/irx/builder/lowering/system.py @@ -7,10 +7,11 @@ from llvmlite import ir from irx import astx +from irx.analysis.types import is_boolean_type, is_unsigned_type from irx.builder.core import VisitorCore from irx.builder.protocols import VisitorMixinBase from irx.builder.runtime import safe_pop -from irx.builder.types import is_fp_type, is_int_type +from irx.builder.types import is_int_type from irx.typecheck import typechecked @@ -29,60 +30,25 @@ def visit(self, node: astx.Cast) -> None: if value is None: raise Exception("Invalid value in Cast") - target_type_str = node.target_type.__class__.__name__.lower() - target_type = self._llvm.get_data_type(target_type_str) - - if value.type == target_type: - self.result_stack.append(value) - return - - result: ir.Value - if is_int_type(value.type) and is_int_type(target_type): - if value.type.width < target_type.width: - result = self._llvm.ir_builder.sext( - value, target_type, "cast_int_up" - ) - else: - result = self._llvm.ir_builder.trunc( - value, target_type, "cast_int_down" - ) - elif is_int_type(value.type) and is_fp_type(target_type): - result = self._llvm.ir_builder.sitofp( - value, target_type, "cast_int_to_fp" - ) - elif is_fp_type(value.type) and is_int_type(target_type): - result = self._llvm.ir_builder.fptosi( - value, target_type, "cast_fp_to_int" - ) - elif isinstance(value.type, ir.FloatType) and isinstance( - target_type, ir.HalfType - ): - result = self._llvm.ir_builder.fptrunc( - value, target_type, "cast_fp_to_half" - ) - elif isinstance(value.type, ir.HalfType) and isinstance( - target_type, ir.FloatType - ): - result = self._llvm.ir_builder.fpext( - value, target_type, "cast_half_to_fp" - ) - elif isinstance(value.type, ir.FloatType) and isinstance( - target_type, ir.FloatType - ): - if value.type.width < target_type.width: - result = self._llvm.ir_builder.fpext( - value, target_type, "cast_fp_up" - ) - else: - result = self._llvm.ir_builder.fptrunc( - value, target_type, "cast_fp_down" - ) - elif target_type in ( + source_type = self._resolved_ast_type(node.value) + target_type = node.target_type + target_llvm_type = self._llvm_type_for_ast_type(target_type) + if target_llvm_type in ( self._llvm.ASCII_STRING_TYPE, self._llvm.UTF8_STRING_TYPE, ): + if ( + isinstance(value.type, ir.PointerType) + and value.type.pointee == self._llvm.INT8_TYPE + ): + self.result_stack.append(value) + return if is_int_type(value.type): - arg, fmt_str = self._normalize_int_for_printf(value) + arg, fmt_str = self._normalize_int_for_printf( + value, + unsigned=is_unsigned_type(source_type) + or is_boolean_type(source_type), + ) fmt_gv = self._get_or_create_format_global(fmt_str) ptr = self._snprintf_heap(fmt_gv, [arg]) self.result_stack.append(ptr) @@ -102,13 +68,13 @@ def visit(self, node: astx.Cast) -> None: self.result_stack.append(ptr) return raise Exception( - f"Unsupported cast from {value.type} to {target_type}" + f"Unsupported cast from {value.type} to {target_llvm_type}" ) - else: - raise Exception( - f"Unsupported cast from {value.type} to {target_type}" - ) - + result = self._cast_ast_value( + value, + source_type=source_type, + target_type=target_type, + ) self.result_stack.append(result) @VisitorCore.visit.dispatch @@ -124,6 +90,7 @@ def visit(self, node: astx.PrintExpr) -> None: if message_value is None: raise Exception("Invalid message in PrintExpr") + message_source_type = self._resolved_ast_type(node.message) message_type = message_value.type ptr: ir.Value if ( @@ -132,7 +99,11 @@ def visit(self, node: astx.PrintExpr) -> None: ): ptr = message_value elif is_int_type(message_type): - int_arg, int_fmt = self._normalize_int_for_printf(message_value) + int_arg, int_fmt = self._normalize_int_for_printf( + message_value, + unsigned=is_unsigned_type(message_source_type) + or is_boolean_type(message_source_type), + ) int_fmt_gv = self._get_or_create_format_global(int_fmt) ptr = self._snprintf_heap(int_fmt_gv, [int_arg]) elif isinstance( diff --git a/src/irx/builder/lowering/variables.py b/src/irx/builder/lowering/variables.py index 8aa2b9c4..2d63488f 100644 --- a/src/irx/builder/lowering/variables.py +++ b/src/irx/builder/lowering/variables.py @@ -39,6 +39,11 @@ def visit(self, expr: astx.VariableAssignment) -> None: llvm_value = safe_pop(self.result_stack) if llvm_value is None: raise Exception("codegen: Invalid value in VariableAssignment.") + llvm_value = self._cast_ast_value( + llvm_value, + source_type=self._resolved_ast_type(expr.value), + target_type=self._resolved_ast_type(expr), + ) llvm_var = self.named_values.get(var_key) if not llvm_var: @@ -85,6 +90,11 @@ def visit(self, node: astx.VariableDeclaration) -> None: init_val = safe_pop(self.result_stack) if init_val is None: raise Exception("Initializer code generation failed.") + init_val = self._cast_ast_value( + init_val, + source_type=self._resolved_ast_type(node.value), + target_type=node.type_, + ) if type_str == "string": alloca = self.create_entry_block_alloca( @@ -148,6 +158,11 @@ def visit(self, node: astx.InlineVariableDeclaration) -> None: init_val = safe_pop(self.result_stack) if init_val is None: raise Exception("Initializer code generation failed.") + init_val = self._cast_ast_value( + init_val, + source_type=self._resolved_ast_type(node.value), + target_type=node.type_, + ) elif "float" in type_str: init_val = ir.Constant(self._llvm.get_data_type(type_str), 0.0) else: diff --git a/src/irx/builder/protocols.py b/src/irx/builder/protocols.py index 4b8d9796..0a9bd9aa 100644 --- a/src/irx/builder/protocols.py +++ b/src/irx/builder/protocols.py @@ -41,6 +41,8 @@ class VisitorProtocol(BaseVisitorProtocol, Protocol): type: str | None _fast_math_enabled: type: bool + _current_function_return_type: + type: astx.DataType | None target: type: llvm.TargetRef target_machine: @@ -58,6 +60,7 @@ class VisitorProtocol(BaseVisitorProtocol, Protocol): llvm_structs_by_qualified_name: dict[str, ir.IdentifiedStructType] entry_function_symbol_id: str | None _fast_math_enabled: bool + _current_function_return_type: astx.DataType | None target: llvm.TargetRef target_machine: llvm.TargetMachine @@ -326,17 +329,23 @@ def _emit_runtime_subscript_lookup( raise NotImplementedError def _normalize_int_for_printf( - self, _value: ir.Value + self, + _value: ir.Value, + *, + unsigned: bool = False, ) -> tuple[ir.Value, str]: """ title: Normalize int for printf. parameters: _value: type: ir.Value + unsigned: + type: bool returns: type: tuple[ir.Value, str] """ - ... + _ = unsigned + return cast(tuple[ir.Value, str], (None, "")) def _snprintf_heap( self, _fmt_gv: ir.GlobalVariable, _args: list[ir.Value] @@ -393,6 +402,8 @@ class VisitorMixinBase: type: str | None _fast_math_enabled: type: bool + _current_function_return_type: + type: astx.DataType | None target: type: llvm.TargetRef target_machine: @@ -410,6 +421,7 @@ class VisitorMixinBase: llvm_structs_by_qualified_name: dict[str, ir.IdentifiedStructType] entry_function_symbol_id: str | None _fast_math_enabled: bool + _current_function_return_type: astx.DataType | None target: llvm.TargetRef target_machine: llvm.TargetMachine @@ -535,6 +547,68 @@ def _is_numeric_value(self, _value: ir.Value) -> bool: """ return False + def _resolved_ast_type( + self, _node: astx.AST | None + ) -> astx.DataType | None: + """ + title: Resolved ast type. + parameters: + _node: + type: astx.AST | None + returns: + type: astx.DataType | None + """ + return cast(astx.DataType | None, None) + + def _cast_ast_value( + self, + _value: ir.Value, + *, + source_type: astx.DataType | None, + target_type: astx.DataType | None, + ) -> ir.Value: + """ + title: Cast ast value. + parameters: + _value: + type: ir.Value + source_type: + type: astx.DataType | None + target_type: + type: astx.DataType | None + returns: + type: ir.Value + """ + _ = source_type + _ = target_type + return cast(ir.Value, None) + + def _coerce_numeric_operands_for_types( + self, + _lhs: ir.Value, + _rhs: ir.Value, + *, + lhs_type: astx.DataType | None, + rhs_type: astx.DataType | None, + ) -> tuple[ir.Value, ir.Value]: + """ + title: Coerce numeric operands for types. + parameters: + _lhs: + type: ir.Value + _rhs: + type: ir.Value + lhs_type: + type: astx.DataType | None + rhs_type: + type: astx.DataType | None + returns: + type: tuple[ir.Value, ir.Value] + """ + _ = lhs_type + _ = rhs_type + return cast(tuple[ir.Value, ir.Value], (None, None)) + def _unify_numeric_operands( self, _lhs: ir.Value, @@ -697,16 +771,22 @@ def _emit_runtime_subscript_lookup( _ = unsigned def _normalize_int_for_printf( - self, _value: ir.Value + self, + _value: ir.Value, + *, + unsigned: bool = False, ) -> tuple[ir.Value, str]: """ title: Normalize int for printf. parameters: _value: type: ir.Value + unsigned: + type: bool returns: type: tuple[ir.Value, str] """ + _ = unsigned return cast(tuple[ir.Value, str], (None, "")) def _snprintf_heap( diff --git a/tests/analysis/test_numeric_semantics.py b/tests/analysis/test_numeric_semantics.py new file mode 100644 index 00000000..b221fc8a --- /dev/null +++ b/tests/analysis/test_numeric_semantics.py @@ -0,0 +1,100 @@ +""" +title: Tests for scalar numeric semantics. +""" + +from __future__ import annotations + +import pytest + +from irx import astx +from irx.analysis import SemanticError, analyze + + +def _semantic_type(node: astx.AST) -> astx.DataType | None: + """ + title: Return the resolved semantic type attached to one node. + parameters: + node: + type: astx.AST + returns: + type: astx.DataType | None + """ + return getattr(getattr(node, "semantic", None), "resolved_type", None) + + +def test_analyze_promotes_int64_and_float32_to_float64() -> None: + """ + title: Mixed int64 and float32 arithmetic should promote to float64. + """ + expr = astx.BinaryOp( + "+", + astx.LiteralInt64(1), + astx.LiteralFloat32(2.0), + ) + + analyze(expr) + + resolved_type = _semantic_type(expr) + assert resolved_type is not None + assert resolved_type.__class__ is astx.Float64 + + +def test_analyze_marks_signed_result_when_signed_operand_is_wider() -> None: + """ + title: Wider signed integers should keep signed semantics against unsigned. + """ + expr = astx.BinaryOp( + "/", + astx.LiteralInt32(9), + astx.LiteralUInt16(2), + ) + + analyze(expr) + + semantic = getattr(expr, "semantic") + assert semantic.resolved_type.__class__ is astx.Int32 + assert semantic.semantic_flags.unsigned is False + + +def test_unsigned_result_when_unsigned_operand_is_not_narrower() -> None: + """ + title: >- + Equal-or-wider unsigned integers should drive mixed integer semantics. + """ + expr = astx.BinaryOp( + ">", + astx.LiteralInt16(-1), + astx.LiteralUInt32(1), + ) + + analyze(expr) + + semantic = getattr(expr, "semantic") + assert semantic.resolved_type.__class__ is astx.Boolean + assert semantic.semantic_flags.unsigned is True + + +def test_analyze_rejects_implicit_signed_to_unsigned_assignment() -> None: + """ + title: >- + Implicit signed-to-unsigned assignment should require an explicit cast. + """ + module = astx.Module() + proto = astx.FunctionPrototype( + "main", + args=astx.Arguments(), + return_type=astx.Int32(), + ) + body = astx.Block() + body.append( + astx.VariableDeclaration( + name="value", + type_=astx.UInt32(), + value=astx.LiteralInt32(1), + ) + ) + body.append(astx.FunctionReturn(astx.LiteralInt32(0))) + module.block.append(astx.FunctionDef(prototype=proto, body=body)) + + with pytest.raises(SemanticError, match="Cannot assign value of type"): + analyze(module) diff --git a/tests/test_numeric_foundation.py b/tests/test_numeric_foundation.py new file mode 100644 index 00000000..5cdb1d6e --- /dev/null +++ b/tests/test_numeric_foundation.py @@ -0,0 +1,182 @@ +""" +title: Tests for scalar numeric lowering and codegen. +""" + +from __future__ import annotations + +from irx import astx +from irx.builder import Builder +from irx.system import Cast, PrintExpr + +from .conftest import assert_ir_parses, make_main_module + + +def test_translate_implicit_integer_widening_for_variable_initializers() -> ( + None +): + """ + title: Variable initializers should widen before storing into the alloca. + """ + module = make_main_module( + astx.VariableDeclaration( + name="value", + type_=astx.Int32(), + value=astx.LiteralInt16(7), + ), + astx.FunctionReturn(astx.LiteralInt32(0)), + ) + + ir_text = Builder().translate(module) + assert "sext i16 7 to i32" in ir_text + assert_ir_parses(ir_text) + + +def test_translate_mixed_int64_and_float32_promotes_to_double() -> None: + """ + title: Mixed int64 and float32 arithmetic should lower through double. + """ + module = make_main_module( + astx.VariableDeclaration( + name="left", + type_=astx.Int64(), + value=astx.LiteralInt64(1), + ), + astx.VariableDeclaration( + name="right", + type_=astx.Float32(), + value=astx.LiteralFloat32(2.0), + ), + astx.InlineVariableDeclaration( + name="result", + type_=astx.Float64(), + value=astx.BinaryOp( + "+", + astx.Identifier("left"), + astx.Identifier("right"), + ), + ), + astx.FunctionReturn(astx.LiteralInt32(0)), + ) + + ir_text = Builder().translate(module) + + assert "sitofp i64" in ir_text + assert "fpext float" in ir_text + assert "fadd double" in ir_text + assert_ir_parses(ir_text) + + +def test_translate_function_call_implicitly_promotes_arguments() -> None: + """ + title: >- + Calls should apply the same implicit promotions as semantic analysis. + """ + module = astx.Module() + + echo_proto = astx.FunctionPrototype( + "echo", + args=astx.Arguments(astx.Argument("value", astx.Int32())), + return_type=astx.Int32(), + ) + echo_body = astx.Block() + echo_body.append(astx.FunctionReturn(astx.Identifier("value"))) + module.block.append(astx.FunctionDef(prototype=echo_proto, body=echo_body)) + + main_proto = astx.FunctionPrototype( + "main", + args=astx.Arguments(), + return_type=astx.Int32(), + ) + main_body = astx.Block() + main_body.append( + PrintExpr(astx.FunctionCall("echo", [astx.LiteralInt16(7)])) + ) + main_body.append(astx.FunctionReturn(astx.LiteralInt32(0))) + module.block.append(astx.FunctionDef(prototype=main_proto, body=main_body)) + + ir_text = Builder().translate(module) + + assert "sext i16 7 to i32" in ir_text + assert 'call i32 @"main__echo"(i32' in ir_text + assert_ir_parses(ir_text) + + +def test_mixed_signed_and_unsigned_comparison_uses_canonical_promotion() -> ( + None +): + """ + title: Mixed signed and unsigned comparisons should follow the shared rule. + """ + module = make_main_module( + astx.FunctionReturn( + astx.BinaryOp( + ">", + astx.LiteralInt16(-1), + astx.LiteralUInt32(1), + ) + ), + return_type=astx.Boolean(), + ) + + ir_text = Builder().translate(module) + + assert "sext i16 -1 to i32" in ir_text + assert "icmp ugt i32" in ir_text + assert_ir_parses(ir_text) + + +def test_translate_bool_numeric_casts_use_truthy_zero_one_semantics() -> None: + """ + title: Bool casts should use 0/1 truthiness instead of sign-extension. + """ + module = make_main_module( + astx.VariableDeclaration( + name="flag", + type_=astx.Boolean(), + value=astx.LiteralBoolean(True), + ), + astx.VariableDeclaration( + name="number", + type_=astx.Int32(), + value=astx.LiteralInt32(2), + ), + PrintExpr( + Cast( + value=astx.Identifier("flag"), + target_type=astx.Int32(), + ) + ), + PrintExpr( + Cast( + value=astx.Identifier("number"), + target_type=astx.Boolean(), + ) + ), + astx.FunctionReturn(astx.LiteralInt32(0)), + ) + + ir_text = Builder().translate(module) + + assert "zext i1" in ir_text + assert "icmp ne i32" in ir_text + assert_ir_parses(ir_text) + + +def test_translate_cast_uint32_to_string_formats_as_unsigned() -> None: + """ + title: Unsigned integer string casts should preserve the unsigned value. + """ + module = make_main_module( + PrintExpr( + Cast( + value=astx.LiteralUInt32(4294967295), + target_type=astx.String(), + ) + ), + astx.FunctionReturn(astx.LiteralInt32(0)), + ) + + ir_text = Builder().translate(module) + + assert "%llu" in ir_text + assert_ir_parses(ir_text) From 0d4cc6c0d017b1fb7e9a900a1b37d0a763480a63 Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Wed, 8 Apr 2026 15:27:09 +0000 Subject: [PATCH 2/4] apply reviewer suggestions --- src/irx/analysis/types.py | 9 ++-- src/irx/builder/core.py | 24 +++------ src/irx/builder/lowering/binary_ops.py | 16 +++--- tests/analysis/test_numeric_semantics.py | 55 +++++++++++++++++++ tests/test_numeric_foundation.py | 69 ++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 28 deletions(-) diff --git a/src/irx/analysis/types.py b/src/irx/analysis/types.py index 4c743d19..d0d15a40 100644 --- a/src/irx/analysis/types.py +++ b/src/irx/analysis/types.py @@ -238,9 +238,10 @@ def _type_for_width( return type_cls() -def _min_float_width_for_integer_width(width: int) -> int: +@public +def float_promotion_width_for_integer_width(width: int) -> int: """ - title: Return the float width used when integers promote with floats. + title: Return the float width floor used when integers promote with floats. parameters: width: type: int @@ -317,7 +318,7 @@ def _common_float_type( ) target_width = max( float_width, - _min_float_width_for_integer_width(integer_width), + float_promotion_width_for_integer_width(integer_width), ) target_width = min(target_width, BIT_WIDTH_64) return _type_for_width(target_width, _FLOATS_BY_WIDTH) @@ -392,7 +393,7 @@ def _is_safe_float_assignment( if is_float_type(value): return target_width >= bit_width(value) if is_integer_type(value): - return target_width >= _min_float_width_for_integer_width( + return target_width >= float_promotion_width_for_integer_width( bit_width(value) ) return False diff --git a/src/irx/builder/core.py b/src/irx/builder/core.py index 6f384fcd..b6cfa41c 100644 --- a/src/irx/builder/core.py +++ b/src/irx/builder/core.py @@ -30,6 +30,7 @@ from irx.analysis.types import ( bit_width, common_numeric_type, + float_promotion_width_for_integer_width, is_boolean_type, is_float_type, is_integer_type, @@ -942,7 +943,11 @@ def _unify_numeric_operands( unsigned: bool = False, ) -> tuple[ir.Value, ir.Value]: """ - title: Unify numeric operands. + title: Unify numeric operands for raw LLVM values. + summary: >- + This is a fallback helper for low-level builder/test usage when + semantic operand types are unavailable. Normal AST lowering should + prefer semantic-aware coercion instead. parameters: lhs: type: ir.Value @@ -1002,7 +1007,7 @@ def _unify_numeric_operands( ) target_width = max( float_width, - self._min_float_width_for_integer_bits(integer_width), + float_promotion_width_for_integer_width(integer_width), ) target_scalar_ty = self._float_type_from_width(target_width) else: @@ -1026,21 +1031,6 @@ def _unify_numeric_operands( return lhs, rhs - def _min_float_width_for_integer_bits(self, width: int) -> int: - """ - title: Minimum float width for integer bits. - parameters: - width: - type: int - returns: - type: int - """ - if width <= FLOAT16_BITS: - return FLOAT16_BITS - if width <= FLOAT32_BITS: - return FLOAT32_BITS - return FLOAT64_BITS - def _select_float_type(self, candidates: list[ir.Type]) -> ir.Type: """ title: Select float type. diff --git a/src/irx/builder/lowering/binary_ops.py b/src/irx/builder/lowering/binary_ops.py index 9442a4d3..35325b67 100644 --- a/src/irx/builder/lowering/binary_ops.py +++ b/src/irx/builder/lowering/binary_ops.py @@ -95,13 +95,15 @@ def _load_binary_operands( and self._is_numeric_value(llvm_lhs) and self._is_numeric_value(llvm_rhs) ): - llvm_lhs, llvm_rhs = self._coerce_numeric_operands_for_types( - llvm_lhs, - llvm_rhs, - lhs_type=lhs_type, - rhs_type=rhs_type, - ) - if llvm_lhs.type != llvm_rhs.type: + if lhs_type is not None and rhs_type is not None: + llvm_lhs, llvm_rhs = self._coerce_numeric_operands_for_types( + llvm_lhs, + llvm_rhs, + lhs_type=lhs_type, + rhs_type=rhs_type, + ) + else: + # This is a low-level fallback for raw LLVM helper use only. llvm_lhs, llvm_rhs = self._unify_numeric_operands( llvm_lhs, llvm_rhs, diff --git a/tests/analysis/test_numeric_semantics.py b/tests/analysis/test_numeric_semantics.py index b221fc8a..7dacc81e 100644 --- a/tests/analysis/test_numeric_semantics.py +++ b/tests/analysis/test_numeric_semantics.py @@ -74,6 +74,23 @@ def test_unsigned_result_when_unsigned_operand_is_not_narrower() -> None: assert semantic.semantic_flags.unsigned is True +def test_analyze_marks_unsigned_result_at_equal_width_boundary() -> None: + """ + title: Equal-width signed and unsigned integers should promote to unsigned. + """ + expr = astx.BinaryOp( + "+", + astx.LiteralInt32(-1), + astx.LiteralUInt32(1), + ) + + analyze(expr) + + semantic = getattr(expr, "semantic") + assert semantic.resolved_type.__class__ is astx.UInt32 + assert semantic.semantic_flags.unsigned is True + + def test_analyze_rejects_implicit_signed_to_unsigned_assignment() -> None: """ title: >- @@ -98,3 +115,41 @@ def test_analyze_rejects_implicit_signed_to_unsigned_assignment() -> None: with pytest.raises(SemanticError, match="Cannot assign value of type"): analyze(module) + + +def test_rejects_implicit_unsigned_to_same_width_signed_assignment() -> None: + """ + title: Unsigned integers need a strictly wider signed target implicitly. + """ + module = astx.Module() + proto = astx.FunctionPrototype( + "main", + args=astx.Arguments(), + return_type=astx.Int32(), + ) + body = astx.Block() + body.append( + astx.VariableDeclaration( + name="value", + type_=astx.Int32(), + value=astx.LiteralUInt32(1), + ) + ) + body.append(astx.FunctionReturn(astx.LiteralInt32(0))) + module.block.append(astx.FunctionDef(prototype=proto, body=body)) + + with pytest.raises(SemanticError, match="Cannot assign value of type"): + analyze(module) + + +def test_analyze_rejects_explicit_string_to_int_cast() -> None: + """ + title: Non-numeric explicit casts should still be rejected semantically. + """ + expr = astx.Cast( + value=astx.LiteralString("7"), + target_type=astx.Int32(), + ) + + with pytest.raises(SemanticError, match="Unsupported cast"): + analyze(expr) diff --git a/tests/test_numeric_foundation.py b/tests/test_numeric_foundation.py index 5cdb1d6e..1461b5a0 100644 --- a/tests/test_numeric_foundation.py +++ b/tests/test_numeric_foundation.py @@ -66,6 +66,30 @@ def test_translate_mixed_int64_and_float32_promotes_to_double() -> None: assert_ir_parses(ir_text) +def test_translate_implicit_integer_widening_for_reassignment() -> None: + """ + title: Reassignments should widen values before storing them. + """ + module = make_main_module( + astx.VariableDeclaration( + name="value", + type_=astx.Int32(), + value=astx.LiteralInt32(0), + mutability=astx.MutabilityKind.mutable, + ), + astx.VariableAssignment( + name="value", + value=astx.LiteralInt16(7), + ), + astx.FunctionReturn(astx.LiteralInt32(0)), + ) + + ir_text = Builder().translate(module) + + assert "sext i16 7 to i32" in ir_text + assert_ir_parses(ir_text) + + def test_translate_function_call_implicitly_promotes_arguments() -> None: """ title: >- @@ -101,6 +125,29 @@ def test_translate_function_call_implicitly_promotes_arguments() -> None: assert_ir_parses(ir_text) +def test_translate_implicit_return_coercion_uses_declared_function_type() -> ( + None +): + """ + title: Returns should coerce to the declared function return type. + """ + module = astx.Module() + proto = astx.FunctionPrototype( + "main", + args=astx.Arguments(), + return_type=astx.Int32(), + ) + body = astx.Block() + body.append(astx.FunctionReturn(astx.LiteralInt16(7))) + module.block.append(astx.FunctionDef(prototype=proto, body=body)) + + ir_text = Builder().translate(module) + + assert "sext i16 7 to i32" in ir_text + assert "ret i32" in ir_text + assert_ir_parses(ir_text) + + def test_mixed_signed_and_unsigned_comparison_uses_canonical_promotion() -> ( None ): @@ -125,6 +172,28 @@ def test_mixed_signed_and_unsigned_comparison_uses_canonical_promotion() -> ( assert_ir_parses(ir_text) +def test_wider_signed_integer_keeps_signed_comparison_semantics() -> None: + """ + title: Wider signed integers should force signed comparison lowering. + """ + module = make_main_module( + astx.FunctionReturn( + astx.BinaryOp( + ">", + astx.LiteralInt64(-1), + astx.LiteralUInt32(1), + ) + ), + return_type=astx.Boolean(), + ) + + ir_text = Builder().translate(module) + + assert "zext i32 1 to i64" in ir_text + assert "icmp sgt i64" in ir_text + assert_ir_parses(ir_text) + + def test_translate_bool_numeric_casts_use_truthy_zero_one_semantics() -> None: """ title: Bool casts should use 0/1 truthiness instead of sign-extension. From 4ce312ee9fe7da4561345e6da694631f0bf1706a Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Wed, 8 Apr 2026 16:27:51 +0000 Subject: [PATCH 3/4] fix tests --- src/irx/builder/lowering/binary_ops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/irx/builder/lowering/binary_ops.py b/src/irx/builder/lowering/binary_ops.py index 35325b67..7171974d 100644 --- a/src/irx/builder/lowering/binary_ops.py +++ b/src/irx/builder/lowering/binary_ops.py @@ -9,6 +9,7 @@ from llvmlite import ir from irx import astx +from irx.analysis.types import common_numeric_type from irx.astx.binary_op import ( SPECIALIZED_BINARY_OP_EXTRA, AddBinOp, @@ -90,12 +91,13 @@ def _load_binary_operands( unsigned = uses_unsigned_semantics(node) lhs_type = self._resolved_ast_type(node.lhs) rhs_type = self._resolved_ast_type(node.rhs) + semantic_numeric_type = common_numeric_type(lhs_type, rhs_type) if ( unify_numeric and self._is_numeric_value(llvm_lhs) and self._is_numeric_value(llvm_rhs) ): - if lhs_type is not None and rhs_type is not None: + if semantic_numeric_type is not None: llvm_lhs, llvm_rhs = self._coerce_numeric_operands_for_types( llvm_lhs, llvm_rhs, From d4196744cb515a0f443e1610cdb06b4ce396e28f Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Wed, 8 Apr 2026 17:41:08 +0000 Subject: [PATCH 4/4] add typechecked --- AGENTS.md | 17 +++- contributing.md | 14 +++ docs/contributing.md | 16 +++ src/irx/__init__.py | 3 + src/irx/analysis/api.py | 4 + src/irx/analysis/contract.py | 1 + src/irx/analysis/module_symbols.py | 10 ++ src/irx/analysis/normalization.py | 3 + src/irx/analysis/session.py | 1 + src/irx/analysis/symbols.py | 2 + src/irx/analysis/types.py | 22 +++++ src/irx/analysis/typing.py | 3 + src/irx/analysis/validation.py | 8 ++ src/irx/astx/__init__.py | 5 + src/irx/astx/binary_op.py | 2 + src/irx/builder/core.py | 10 ++ src/irx/builder/runtime/arrow/feature.py | 13 +++ src/irx/builder/runtime/feature_libc.py | 6 ++ src/irx/builder/runtime/features.py | 1 + src/irx/builder/runtime/linking.py | 4 + src/irx/builder/runtime/registry.py | 1 + src/irx/builder/types.py | 2 + src/irx/builder/vector.py | 5 + src/irx/typecheck.py | 19 ++-- tests/test_typechecked_policy.py | 120 +++++++++++++++++++++-- 25 files changed, 272 insertions(+), 20 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index e0bab329..6f888a46 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -176,14 +176,21 @@ API. ## Runtime Type Checking -- Use `irx.typecheck.typechecked` on every concrete class defined under - `src/irx/`. -- Keep `@public` or `@private` outermost, then `@typechecked`, then +- Use `irx.typecheck.typechecked` on every module-level function and every + concrete class defined under `src/irx/`. +- Class-level `@typechecked` is the default way to cover methods. Do not add + per-method decorators just to mirror the policy unless a class cannot be + decorated, and if that ever happens document the reason. +- For functions, keep `@public` or `@private` outermost and place `@typechecked` + on the implementation boundary. Usually that means directly under `@public` or + `@private`; for wrappers like `@lru_cache(...)`, keep `@typechecked` closest + to the original function so Typeguard can instrument it. +- For classes, keep `@public` or `@private` outermost, then `@typechecked`, then `@dataclass(...)` so generated dataclass methods are instrumented. - Exempt only `Protocol` definitions and type-checking-only helper stubs that are intentionally kept out of the runtime MRO. -- If a concrete class truly needs an exemption, document the reason and update - `tests/test_typechecked_policy.py` in the same change. +- If a function or concrete class truly needs an exemption, document the reason + and update `tests/test_typechecked_policy.py` in the same change. ## Working In `irx.builder` diff --git a/contributing.md b/contributing.md index 5c02e587..bc99fdcb 100644 --- a/contributing.md +++ b/contributing.md @@ -26,6 +26,20 @@ makim tests.linter makim tests.unittest ``` +## Runtime type checking + +- Use `irx.typecheck.typechecked` on every module-level function and every + concrete class under `src/irx`. +- Class decorators are the default way to cover methods; do not add per-method + decorators unless a class cannot be decorated. +- Keep `@public` or `@private` outermost and place `@typechecked` on the + implementation boundary; for wrappers like `@lru_cache(...)`, that means + keeping `@typechecked` closest to the original function. +- Keep class decorators ordered as `@public` or `@private`, then `@typechecked`, + then `@dataclass(...)`. +- Run `pytest tests/test_typechecked_policy.py -q` when you touch decorator + coverage or add an exemption. + ## Full guidelines Please see the full contributing guide for project layout, workflow, and release diff --git a/docs/contributing.md b/docs/contributing.md index b3d12cac..110eb665 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -95,6 +95,22 @@ Ready to contribute? Here’s how to set up `irx` for local development. $ makim tests.unittest ``` +## Runtime Type Checking + +IRx keeps runtime type checking on by default for its own code under `src/irx`. + +- Use `irx.typecheck.typechecked` on every module-level function and every + concrete class. +- Methods are expected to be covered through the class decorator; avoid adding + per-method decorators unless the class itself cannot be decorated. +- Keep `@public` or `@private` outermost and place `@typechecked` on the + implementation boundary; for wrappers like `@lru_cache(...)`, that means + keeping `@typechecked` closest to the original function. +- Keep class decorators ordered as `@public` or `@private`, then `@typechecked`, + then `@dataclass(...)`. +- If you need an exemption for a `Protocol` or a typing-only stub, document it + clearly and update `tests/test_typechecked_policy.py` in the same change. + 6. Commit your changes and push your branch to GitHub: ```bash diff --git a/src/irx/__init__.py b/src/irx/__init__.py index 6dc59b77..3544351e 100644 --- a/src/irx/__init__.py +++ b/src/irx/__init__.py @@ -4,7 +4,10 @@ from importlib import metadata as importlib_metadata +from irx.typecheck import typechecked + +@typechecked def get_version() -> str: """ title: Return the program version. diff --git a/src/irx/analysis/api.py b/src/irx/analysis/api.py index 5d65461c..0f6270b8 100644 --- a/src/irx/analysis/api.py +++ b/src/irx/analysis/api.py @@ -19,6 +19,7 @@ ParsedModule, ) from irx.analysis.session import CompilationSession +from irx.typecheck import typechecked __all__ = [ "SemanticAnalyzer", @@ -29,6 +30,7 @@ @public +@typechecked def analyze(node: astx.AST) -> astx.AST: """ title: Analyze one AST root and attach semantic sidecars. @@ -46,6 +48,7 @@ def analyze(node: astx.AST) -> astx.AST: @public +@typechecked def analyze_module(module: astx.Module) -> astx.Module: """ title: Analyze an AST module. @@ -62,6 +65,7 @@ def analyze_module(module: astx.Module) -> astx.Module: @public +@typechecked def analyze_modules( root: ParsedModule, resolver: ImportResolver, diff --git a/src/irx/analysis/contract.py b/src/irx/analysis/contract.py index 3bd3a3f5..1bc08546 100644 --- a/src/irx/analysis/contract.py +++ b/src/irx/analysis/contract.py @@ -233,6 +233,7 @@ class SemanticContract: @public +@typechecked def get_semantic_contract() -> SemanticContract: """ title: Return the stable public semantic contract. diff --git a/src/irx/analysis/module_symbols.py b/src/irx/analysis/module_symbols.py index 4b2b8f04..b089a311 100644 --- a/src/irx/analysis/module_symbols.py +++ b/src/irx/analysis/module_symbols.py @@ -12,10 +12,12 @@ from public import public from irx.analysis.module_interfaces import ModuleKey +from irx.typecheck import typechecked _SEGMENT_RE = re.compile(r"[A-Za-z0-9_]+") +@typechecked def _split_segments(value: str) -> list[str]: """ title: Split a string into LLVM-friendly segments. @@ -33,6 +35,7 @@ def _split_segments(value: str) -> list[str]: return [f"x{ord(char):02x}" for char in value] +@typechecked def _mangle_parts(*parts: str) -> str: """ title: Mangle string parts into a deterministic LLVM name. @@ -50,6 +53,7 @@ def _mangle_parts(*parts: str) -> str: @public +@typechecked def function_key(module_key: ModuleKey, name: str) -> tuple[ModuleKey, str]: """ title: Return a module-aware function registry key. @@ -65,6 +69,7 @@ def function_key(module_key: ModuleKey, name: str) -> tuple[ModuleKey, str]: @public +@typechecked def struct_key(module_key: ModuleKey, name: str) -> tuple[ModuleKey, str]: """ title: Return a module-aware struct registry key. @@ -80,6 +85,7 @@ def struct_key(module_key: ModuleKey, name: str) -> tuple[ModuleKey, str]: @public +@typechecked def qualified_function_name(module_key: ModuleKey, name: str) -> str: """ title: Return a qualified semantic function name. @@ -95,6 +101,7 @@ def qualified_function_name(module_key: ModuleKey, name: str) -> str: @public +@typechecked def qualified_struct_name(module_key: ModuleKey, name: str) -> str: """ title: Return a qualified semantic struct name. @@ -110,6 +117,7 @@ def qualified_struct_name(module_key: ModuleKey, name: str) -> str: @public +@typechecked def qualified_local_name( module_key: ModuleKey, kind: str, @@ -134,6 +142,7 @@ def qualified_local_name( @public +@typechecked def mangle_function_name(module_key: ModuleKey, function_name: str) -> str: """ title: Return a deterministic LLVM function name. @@ -149,6 +158,7 @@ def mangle_function_name(module_key: ModuleKey, function_name: str) -> str: @public +@typechecked def mangle_struct_name(module_key: ModuleKey, struct_name: str) -> str: """ title: Return a deterministic LLVM struct name. diff --git a/src/irx/analysis/normalization.py b/src/irx/analysis/normalization.py index 9ecb3d2f..dc74bb89 100644 --- a/src/irx/analysis/normalization.py +++ b/src/irx/analysis/normalization.py @@ -14,8 +14,10 @@ is_integer_type, is_unsigned_type, ) +from irx.typecheck import typechecked +@typechecked def normalize_flags( node: astx.AST, *, @@ -55,6 +57,7 @@ def normalize_flags( ) +@typechecked def normalize_operator( op_code: str, *, diff --git a/src/irx/analysis/session.py b/src/irx/analysis/session.py index 109ee1e3..76429329 100644 --- a/src/irx/analysis/session.py +++ b/src/irx/analysis/session.py @@ -22,6 +22,7 @@ from irx.typecheck import typechecked +@typechecked def _module_import_specifier(node: astx.ImportFromStmt) -> str: """ title: Return the resolver-facing module specifier for import-from nodes. diff --git a/src/irx/analysis/symbols.py b/src/irx/analysis/symbols.py index 2ba467b5..95c2ceaa 100644 --- a/src/irx/analysis/symbols.py +++ b/src/irx/analysis/symbols.py @@ -12,9 +12,11 @@ from irx.analysis.module_interfaces import ModuleKey from irx.analysis.module_symbols import qualified_local_name from irx.analysis.resolved_nodes import SemanticSymbol +from irx.typecheck import typechecked @public +@typechecked def variable_symbol( symbol_id: str, module_key: ModuleKey, diff --git a/src/irx/analysis/types.py b/src/irx/analysis/types.py index d0d15a40..479618b7 100644 --- a/src/irx/analysis/types.py +++ b/src/irx/analysis/types.py @@ -10,6 +10,7 @@ from public import public from irx import astx +from irx.typecheck import typechecked INT_TYPES = (astx.Int8, astx.Int16, astx.Int32, astx.Int64) UINT_TYPES = (astx.UInt8, astx.UInt16, astx.UInt32, astx.UInt64, astx.UInt128) @@ -57,6 +58,7 @@ @public +@typechecked def clone_type(type_: astx.DataType) -> astx.DataType: """ title: Clone an AST type by class. @@ -70,6 +72,7 @@ def clone_type(type_: astx.DataType) -> astx.DataType: @public +@typechecked def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: """ title: Return whether two AST types share the same class. @@ -87,6 +90,7 @@ def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: @public +@typechecked def is_integer_type(type_: astx.DataType | None) -> bool: """ title: Is integer type. @@ -100,6 +104,7 @@ def is_integer_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_signed_integer_type(type_: astx.DataType | None) -> bool: """ title: Is signed integer type. @@ -113,6 +118,7 @@ def is_signed_integer_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_unsigned_type(type_: astx.DataType | None) -> bool: """ title: Is unsigned type. @@ -126,6 +132,7 @@ def is_unsigned_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_float_type(type_: astx.DataType | None) -> bool: """ title: Is float type. @@ -139,6 +146,7 @@ def is_float_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_numeric_type(type_: astx.DataType | None) -> bool: """ title: Is numeric type. @@ -152,6 +160,7 @@ def is_numeric_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_boolean_type(type_: astx.DataType | None) -> bool: """ title: Is boolean type. @@ -165,6 +174,7 @@ def is_boolean_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_string_type(type_: astx.DataType | None) -> bool: """ title: Is string type. @@ -178,6 +188,7 @@ def is_string_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_temporal_type(type_: astx.DataType | None) -> bool: """ title: Is temporal type. @@ -191,6 +202,7 @@ def is_temporal_type(type_: astx.DataType | None) -> bool: @public +@typechecked def is_none_type(type_: astx.DataType | None) -> bool: """ title: Is none type. @@ -204,6 +216,7 @@ def is_none_type(type_: astx.DataType | None) -> bool: @public +@typechecked def bit_width(type_: astx.DataType | None) -> int: """ title: Return the nominal bit width for numeric types. @@ -218,6 +231,7 @@ def bit_width(type_: astx.DataType | None) -> int: return _BIT_WIDTHS.get(type(type_), 0) +@typechecked def _type_for_width( width: int, table: dict[int, type[astx.DataType]], @@ -239,6 +253,7 @@ def _type_for_width( @public +@typechecked def float_promotion_width_for_integer_width(width: int) -> int: """ title: Return the float width floor used when integers promote with floats. @@ -255,6 +270,7 @@ def float_promotion_width_for_integer_width(width: int) -> int: return BIT_WIDTH_64 +@typechecked def _common_integer_type( lhs: astx.DataType, rhs: astx.DataType, @@ -295,6 +311,7 @@ def _common_integer_type( ) +@typechecked def _common_float_type( lhs: astx.DataType, rhs: astx.DataType, @@ -325,6 +342,7 @@ def _common_float_type( @public +@typechecked def common_numeric_type( lhs: astx.DataType | None, rhs: astx.DataType | None, @@ -349,6 +367,7 @@ def common_numeric_type( return _common_integer_type(lhs, rhs) +@typechecked def _is_safe_integer_assignment( target: astx.DataType, value: astx.DataType, @@ -375,6 +394,7 @@ def _is_safe_integer_assignment( return False +@typechecked def _is_safe_float_assignment( target: astx.DataType, value: astx.DataType, @@ -400,6 +420,7 @@ def _is_safe_float_assignment( @public +@typechecked def is_explicitly_castable( source: astx.DataType | None, target: astx.DataType | None, @@ -432,6 +453,7 @@ def is_explicitly_castable( @public +@typechecked def is_assignable( target: astx.DataType | None, value: astx.DataType | None, diff --git a/src/irx/analysis/typing.py b/src/irx/analysis/typing.py index 21dac2ff..a694d040 100644 --- a/src/irx/analysis/typing.py +++ b/src/irx/analysis/typing.py @@ -14,8 +14,10 @@ is_numeric_type, is_string_type, ) +from irx.typecheck import typechecked +@typechecked def binary_result_type( op_code: str, lhs_type: astx.DataType | None, @@ -62,6 +64,7 @@ def binary_result_type( return lhs_type if lhs_type == rhs_type else None +@typechecked def unary_result_type( op_code: str, operand_type: astx.DataType | None, diff --git a/src/irx/analysis/validation.py b/src/irx/analysis/validation.py index 362c15bd..27ee861e 100644 --- a/src/irx/analysis/validation.py +++ b/src/irx/analysis/validation.py @@ -16,6 +16,7 @@ is_assignable, is_explicitly_castable, ) +from irx.typecheck import typechecked TIME_PARTS_HOUR_MINUTE = 2 TIME_PARTS_HOUR_MINUTE_SECOND = 3 @@ -25,6 +26,7 @@ INT32_MAX = 2**31 - 1 +@typechecked def validate_assignment( diagnostics: DiagnosticBag, *, @@ -54,6 +56,7 @@ def validate_assignment( ) +@typechecked def validate_call( diagnostics: DiagnosticBag, *, @@ -85,6 +88,7 @@ def validate_call( ) +@typechecked def validate_cast( diagnostics: DiagnosticBag, *, @@ -114,6 +118,7 @@ def validate_cast( ) +@typechecked def validate_literal_time(value: str) -> time: """ title: Validate an astx time literal. @@ -149,6 +154,7 @@ def validate_literal_time(value: str) -> time: return time(hour, minute, second) +@typechecked def validate_literal_timestamp(value: str) -> datetime: """ title: Validate an astx timestamp literal. @@ -168,6 +174,7 @@ def validate_literal_timestamp(value: str) -> datetime: raise ValueError(str(exc)) from exc +@typechecked def validate_literal_datetime(value: str) -> datetime: """ title: Validate an astx datetime literal. @@ -232,6 +239,7 @@ def validate_literal_datetime(value: str) -> datetime: raise ValueError("invalid calendar date/time") from exc +@typechecked def validate_calendar_date(value: str) -> date: """ title: Validate a date component. diff --git a/src/irx/astx/__init__.py b/src/irx/astx/__init__.py index f8d354fb..da735185 100644 --- a/src/irx/astx/__init__.py +++ b/src/irx/astx/__init__.py @@ -71,6 +71,7 @@ ) from irx.astx.system import Cast as Cast from irx.astx.system import PrintExpr as PrintExpr +from irx.typecheck import typechecked __all__ = ( "SPECIALIZED_BINARY_OP_EXTRA", @@ -118,3 +119,7 @@ def __dir__() -> list[str]: type: list[str] """ return sorted(set(dir(_upstream_astx)) | set(__all__)) + + +__getattr__ = typechecked(__getattr__) +__dir__ = typechecked(__dir__) diff --git a/src/irx/astx/binary_op.py b/src/irx/astx/binary_op.py index d14730a2..16eac7dd 100644 --- a/src/irx/astx/binary_op.py +++ b/src/irx/astx/binary_op.py @@ -153,6 +153,7 @@ class BitXorBinOp(astx.BinaryOp): } +@typechecked def binary_op_type_for_opcode(op_code: str) -> type[astx.BinaryOp]: """ title: Return the specialized BinaryOp subclass for an opcode. @@ -165,6 +166,7 @@ def binary_op_type_for_opcode(op_code: str) -> type[astx.BinaryOp]: return _BINARY_OP_TYPES.get(op_code, astx.BinaryOp) +@typechecked def specialize_binary_op(node: astx.BinaryOp) -> astx.BinaryOp: """ title: Return a specialized BinaryOp instance for the given opcode. diff --git a/src/irx/builder/core.py b/src/irx/builder/core.py index b6cfa41c..3f616da8 100644 --- a/src/irx/builder/core.py +++ b/src/irx/builder/core.py @@ -61,6 +61,7 @@ @private +@typechecked def is_unsigned_node(node: astx.AST) -> bool: """ title: Is unsigned node. @@ -75,6 +76,7 @@ def is_unsigned_node(node: astx.AST) -> bool: @private +@typechecked def uses_unsigned_semantics(node: astx.AST) -> bool: """ title: Uses unsigned semantics. @@ -97,6 +99,7 @@ def uses_unsigned_semantics(node: astx.AST) -> bool: @private +@typechecked def semantic_symbol_key(node: astx.AST, fallback: str) -> str: """ title: Semantic symbol key. @@ -117,6 +120,7 @@ def semantic_symbol_key(node: astx.AST, fallback: str) -> str: @private +@typechecked def semantic_assignment_key(node: astx.AST, fallback: str) -> str: """ title: Semantic assignment key. @@ -138,6 +142,7 @@ def semantic_assignment_key(node: astx.AST, fallback: str) -> str: @private +@typechecked def semantic_function_key(node: astx.AST, fallback: str) -> str: """ title: Semantic function key. @@ -158,6 +163,7 @@ def semantic_function_key(node: astx.AST, fallback: str) -> str: @private +@typechecked def semantic_function_name(node: astx.AST, fallback: str) -> str: """ title: Semantic LLVM function name. @@ -179,6 +185,7 @@ def semantic_function_name(node: astx.AST, fallback: str) -> str: @private +@typechecked def semantic_struct_key(node: astx.AST, fallback: str) -> str: """ title: Semantic struct key. @@ -199,6 +206,7 @@ def semantic_struct_key(node: astx.AST, fallback: str) -> str: @private +@typechecked def semantic_struct_name(node: astx.AST, fallback: str) -> str: """ title: Semantic LLVM struct name. @@ -220,6 +228,7 @@ def semantic_struct_name(node: astx.AST, fallback: str) -> str: @private +@typechecked def semantic_flag(node: astx.AST, name: str, default: bool = False) -> bool: """ title: Semantic flag. @@ -241,6 +250,7 @@ def semantic_flag(node: astx.AST, name: str, default: bool = False) -> bool: @private +@typechecked def semantic_fma_rhs(node: astx.AST) -> astx.AST | None: """ title: Semantic fma rhs. diff --git a/src/irx/builder/runtime/arrow/feature.py b/src/irx/builder/runtime/arrow/feature.py index 6efad0a6..dfad324f 100644 --- a/src/irx/builder/runtime/arrow/feature.py +++ b/src/irx/builder/runtime/arrow/feature.py @@ -16,6 +16,7 @@ RuntimeFeature, declare_external_function, ) +from irx.typecheck import typechecked if TYPE_CHECKING: from irx.builder.protocols import VisitorProtocol @@ -23,6 +24,7 @@ IRX_ARROW_TYPE_INT32 = 1 +@typechecked def build_arrow_runtime_feature() -> RuntimeFeature: """ title: Build the Arrow runtime feature specification. @@ -118,6 +120,7 @@ def build_arrow_runtime_feature() -> RuntimeFeature: ) +@typechecked def _declare_builder_int32_new( visitor: VisitorProtocol, ) -> ir.Function: @@ -140,6 +143,7 @@ def _declare_builder_int32_new( ) +@typechecked def _declare_builder_append_int32( visitor: VisitorProtocol, ) -> ir.Function: @@ -165,6 +169,7 @@ def _declare_builder_append_int32( ) +@typechecked def _declare_builder_finish( visitor: VisitorProtocol, ) -> ir.Function: @@ -190,6 +195,7 @@ def _declare_builder_finish( ) +@typechecked def _declare_builder_release( visitor: VisitorProtocol, ) -> ir.Function: @@ -212,6 +218,7 @@ def _declare_builder_release( ) +@typechecked def _declare_array_length( visitor: VisitorProtocol, ) -> ir.Function: @@ -234,6 +241,7 @@ def _declare_array_length( ) +@typechecked def _declare_array_null_count( visitor: VisitorProtocol, ) -> ir.Function: @@ -256,6 +264,7 @@ def _declare_array_null_count( ) +@typechecked def _declare_array_type_id( visitor: VisitorProtocol, ) -> ir.Function: @@ -278,6 +287,7 @@ def _declare_array_type_id( ) +@typechecked def _declare_array_export( visitor: VisitorProtocol, ) -> ir.Function: @@ -305,6 +315,7 @@ def _declare_array_export( ) +@typechecked def _declare_array_import( visitor: VisitorProtocol, ) -> ir.Function: @@ -332,6 +343,7 @@ def _declare_array_import( ) +@typechecked def _declare_array_release( visitor: VisitorProtocol, ) -> ir.Function: @@ -354,6 +366,7 @@ def _declare_array_release( ) +@typechecked def _declare_last_error( visitor: VisitorProtocol, ) -> ir.Function: diff --git a/src/irx/builder/runtime/feature_libc.py b/src/irx/builder/runtime/feature_libc.py index a1168367..659cce27 100644 --- a/src/irx/builder/runtime/feature_libc.py +++ b/src/irx/builder/runtime/feature_libc.py @@ -13,11 +13,13 @@ RuntimeFeature, declare_external_function, ) +from irx.typecheck import typechecked if TYPE_CHECKING: from irx.builder.protocols import VisitorProtocol +@typechecked def build_libc_runtime_feature() -> RuntimeFeature: """ title: Build the libc runtime feature specification. @@ -35,6 +37,7 @@ def build_libc_runtime_feature() -> RuntimeFeature: ) +@typechecked def _declare_exit(visitor: VisitorProtocol) -> ir.Function: """ title: Declare exit. @@ -51,6 +54,7 @@ def _declare_exit(visitor: VisitorProtocol) -> ir.Function: return declare_external_function(visitor._llvm.module, "exit", fn_type) +@typechecked def _declare_malloc(visitor: VisitorProtocol) -> ir.Function: """ title: Declare malloc. @@ -67,6 +71,7 @@ def _declare_malloc(visitor: VisitorProtocol) -> ir.Function: return declare_external_function(visitor._llvm.module, "malloc", fn_type) +@typechecked def _declare_puts(visitor: VisitorProtocol) -> ir.Function: """ title: Declare puts. @@ -83,6 +88,7 @@ def _declare_puts(visitor: VisitorProtocol) -> ir.Function: return declare_external_function(visitor._llvm.module, "puts", fn_type) +@typechecked def _declare_snprintf(visitor: VisitorProtocol) -> ir.Function: """ title: Declare snprintf. diff --git a/src/irx/builder/runtime/features.py b/src/irx/builder/runtime/features.py index 77bb726d..196f2d1b 100644 --- a/src/irx/builder/runtime/features.py +++ b/src/irx/builder/runtime/features.py @@ -88,6 +88,7 @@ class RuntimeFeature: metadata: Mapping[str, object] = field(default_factory=dict) +@typechecked def declare_external_function( module: ir.Module, name: str, fn_type: ir.FunctionType ) -> ir.Function: diff --git a/src/irx/builder/runtime/linking.py b/src/irx/builder/runtime/linking.py index be9f4d27..6ccbdb69 100644 --- a/src/irx/builder/runtime/linking.py +++ b/src/irx/builder/runtime/linking.py @@ -31,6 +31,7 @@ class NativeLinkInputs: linker_flags: tuple[str, ...] +@typechecked def compile_native_artifacts( artifacts: Sequence[NativeArtifact], build_dir: Path, @@ -72,6 +73,7 @@ def compile_native_artifacts( return NativeLinkInputs(tuple(objects), tuple(linker_flags)) +@typechecked def link_executable( primary_object: Path, output_file: Path, @@ -108,6 +110,7 @@ def link_executable( _run_checked(command) +@typechecked def _compile_c_source( artifact: NativeArtifact, build_dir: Path, @@ -145,6 +148,7 @@ def _compile_c_source( return object_path +@typechecked def _run_checked(command: Sequence[str]) -> None: """ title: Run checked. diff --git a/src/irx/builder/runtime/registry.py b/src/irx/builder/runtime/registry.py index 81eea966..15b1070c 100644 --- a/src/irx/builder/runtime/registry.py +++ b/src/irx/builder/runtime/registry.py @@ -231,6 +231,7 @@ def linker_flags(self) -> tuple[str, ...]: @lru_cache(maxsize=1) +@typechecked def get_default_runtime_feature_registry() -> RuntimeFeatureRegistry: """ title: Build the default runtime feature registry. diff --git a/src/irx/builder/types.py b/src/irx/builder/types.py index 739c62db..d104b052 100644 --- a/src/irx/builder/types.py +++ b/src/irx/builder/types.py @@ -15,6 +15,7 @@ from irx.typecheck import typechecked +@typechecked def is_fp_type(type_: ir.Type) -> bool: """ title: Is fp type. @@ -30,6 +31,7 @@ def is_fp_type(type_: ir.Type) -> bool: return isinstance(type_, tuple(fp_types)) +@typechecked def is_int_type(type_: ir.Type) -> bool: """ title: Is int type. diff --git a/src/irx/builder/vector.py b/src/irx/builder/vector.py index 93991ec4..665f41ba 100644 --- a/src/irx/builder/vector.py +++ b/src/irx/builder/vector.py @@ -8,8 +8,10 @@ from llvmlite.ir import VectorType from irx.builder.types import is_fp_type +from irx.typecheck import typechecked +@typechecked def is_vector(value: ir.Value) -> bool: """ title: Is vector. @@ -22,6 +24,7 @@ def is_vector(value: ir.Value) -> bool: return isinstance(getattr(value, "type", None), VectorType) +@typechecked def emit_int_div( ir_builder: ir.IRBuilder, lhs: ir.Value, @@ -49,6 +52,7 @@ def emit_int_div( ) +@typechecked def emit_add( ir_builder: ir.IRBuilder, lhs: ir.Value, @@ -74,6 +78,7 @@ def emit_add( return ir_builder.add(lhs, rhs, name=name) +@typechecked def splat_scalar( ir_builder: ir.IRBuilder, scalar: ir.Value, diff --git a/src/irx/typecheck.py b/src/irx/typecheck.py index a2ca9e6c..bc739e2c 100644 --- a/src/irx/typecheck.py +++ b/src/irx/typecheck.py @@ -16,10 +16,19 @@ _T = TypeVar("_T") +typechecked = _typechecked( + forward_ref_policy=ForwardRefPolicy.IGNORE, + collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS, +) + +global_config.forward_ref_policy = ForwardRefPolicy.IGNORE +global_config.collection_check_strategy = CollectionCheckStrategy.ALL_ITEMS + __all__ = ["copy_type", "skip_unused", "typechecked"] @public +@typechecked def skip_unused(*args: Any, **kwargs: Any) -> None: """ title: Referencing variables to pacify static analyzers. @@ -38,6 +47,7 @@ def skip_unused(*args: Any, **kwargs: Any) -> None: @public +@typechecked def copy_type(f: _T) -> Callable[[Any], _T]: """ title: Copy types for args, kwargs from parent class. @@ -49,12 +59,3 @@ def copy_type(f: _T) -> Callable[[Any], _T]: """ skip_unused(f) return lambda x: x - - -typechecked = _typechecked( - forward_ref_policy=ForwardRefPolicy.IGNORE, - collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS, -) - -global_config.forward_ref_policy = ForwardRefPolicy.IGNORE -global_config.collection_check_strategy = CollectionCheckStrategy.ALL_ITEMS diff --git a/tests/test_typechecked_policy.py b/tests/test_typechecked_policy.py index 125e4b5f..7ec15904 100644 --- a/tests/test_typechecked_policy.py +++ b/tests/test_typechecked_policy.py @@ -30,9 +30,121 @@ def _expr_name(node: ast.expr) -> str: return ast.unparse(node) +def _decorator_names( + node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef, +) -> set[str]: + """ + title: Return normalized decorator names for a function or class. + parameters: + node: + type: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef + returns: + type: set[str] + """ + return {_expr_name(decorator) for decorator in node.decorator_list} + + +def _is_typechecking_stub( + node: ast.FunctionDef | ast.AsyncFunctionDef, +) -> bool: + """ + title: Detect typing-only stub functions kept out of runtime. + parameters: + node: + type: ast.FunctionDef | ast.AsyncFunctionDef + returns: + type: bool + """ + if len(node.body) != 1: + return False + stmt = node.body[0] + return ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and stmt.value.value is Ellipsis + ) + + +def _is_protocol_class(node: ast.ClassDef) -> bool: + """ + title: Return whether a class is a typing Protocol. + parameters: + node: + type: ast.ClassDef + returns: + type: bool + """ + base_names = {_expr_name(base) for base in node.bases} + return "Protocol" in base_names + + +def _rebound_typechecked_names(tree: ast.Module) -> set[str]: + """ + title: Return top-level function names wrapped by assignment. + parameters: + tree: + type: ast.Module + returns: + type: set[str] + """ + rebound: set[str] = set() + + for node in tree.body: + if not isinstance(node, ast.Assign): + continue + if len(node.targets) != 1: + continue + target = node.targets[0] + if not isinstance(target, ast.Name): + continue + if not isinstance(node.value, ast.Call): + continue + if _expr_name(node.value.func) != "typechecked": + continue + if len(node.value.args) != 1: + continue + arg = node.value.args[0] + if isinstance(arg, ast.Name) and arg.id == target.id: + rebound.add(target.id) + + return rebound + + +def test_project_functions_are_typechecked() -> None: + """ + title: Assert module-level project functions use the typechecked decorator. + """ + missing: list[str] = [] + + for path in sorted(SOURCE_ROOT.rglob("*.py")): + tree = ast.parse(path.read_text(), filename=str(path)) + rebound = _rebound_typechecked_names(tree) + for node in tree.body: + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if _is_typechecking_stub(node): + continue + if "typechecked" in _decorator_names(node): + continue + if node.name in rebound: + continue + + missing.append( + f"{path.relative_to(REPO_ROOT)}:{node.lineno} {node.name}" + ) + + assert not missing, ( + "Module-level functions under src/irx must use " + "irx.typecheck.typechecked:\n" + "\n".join(missing) + ) + + def test_concrete_project_classes_are_typechecked() -> None: """ title: Assert concrete project classes use the typechecked decorator. + summary: >- + Concrete classes are required to use the class decorator so their methods + are covered at runtime without repeating typechecked on every method. """ missing: list[str] = [] @@ -42,14 +154,10 @@ def test_concrete_project_classes_are_typechecked() -> None: if not isinstance(node, ast.ClassDef): continue - base_names = {_expr_name(base) for base in node.bases} - if "Protocol" in base_names: + if _is_protocol_class(node): continue - decorator_names = { - _expr_name(decorator) for decorator in node.decorator_list - } - if "typechecked" in decorator_names: + if "typechecked" in _decorator_names(node): continue missing.append(