From 4a6dec004a35aba792a1eaa7f6567bfacc457fc9 Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Wed, 8 Apr 2026 19:47:13 +0000 Subject: [PATCH] feat: Make booleans first-class across semantics and lowering --- docs/semantic-contract.md | 46 ++ src/irx/analysis/contract.py | 1 + src/irx/analysis/handlers/base.py | 204 +++++++ src/irx/analysis/handlers/declarations.py | 195 ++++++- src/irx/analysis/handlers/expressions.py | 72 ++- src/irx/analysis/resolved_nodes.py | 55 ++ src/irx/analysis/types.py | 11 + src/irx/astx/__init__.py | 4 + src/irx/astx/structs.py | 135 +++++ src/irx/builder/core.py | 97 +++- src/irx/builder/lowering/binary_ops.py | 45 +- src/irx/builder/lowering/control_flow.py | 4 + src/irx/builder/lowering/functions.py | 29 +- src/irx/builder/lowering/modules.py | 26 +- src/irx/builder/lowering/variables.py | 57 +- src/irx/builder/protocols.py | 30 +- tests/analysis/test_api.py | 1 + tests/test_struct_definition.py | 653 +++++++++++++++++++--- 18 files changed, 1537 insertions(+), 128 deletions(-) create mode 100644 src/irx/astx/structs.py diff --git a/docs/semantic-contract.md b/docs/semantic-contract.md index f03213a4..124fba13 100644 --- a/docs/semantic-contract.md +++ b/docs/semantic-contract.md @@ -34,6 +34,7 @@ Before lowering starts, IRx guarantees that analyzed nodes may carry - `resolved_imports` - `resolved_operator` - `resolved_assignment` +- `resolved_field_access` - `semantic_flags` - `extras` @@ -78,6 +79,51 @@ Boolean behavior is part of the stable semantic boundary: Lowering should branch directly on the analyzed Boolean `i1` value for control flow instead of inventing zero-comparison truthiness rules during codegen. +## Struct Contract + +Structs are IRx's stable composite storage and ABI foundation. + +- struct names are stable semantic symbols +- field order is exactly declaration order +- field names must be unique within a struct +- field types must resolve semantically before lowering +- field layout must not be implicitly reordered by semantics or lowering +- field access must resolve semantically before codegen and lower by stable + field index +- nested structs by value are allowed when every referenced struct is fully + defined +- direct by-value recursive structs are forbidden +- mutual by-value recursive structs are forbidden +- structs can be passed and returned by value within IRx-defined functions +- emitted LLVM struct types are plain data with no hidden headers, metadata, + tags, or runtime object payloads + +For now, empty structs are rejected explicitly instead of relying on backend- +specific behavior. + +Example scalar wrapper: + +```python +astx.StructDefStmt( + name="ScalarBox", + attributes=[ + astx.VariableDeclaration(name="value", type_=astx.Int32()), + ], +) +``` + +Example nested record: + +```python +astx.StructDefStmt( + name="Descriptor", + attributes=[ + astx.VariableDeclaration(name="point", type_=astx.StructType("Point")), + astx.VariableDeclaration(name="ready", type_=astx.Boolean()), + ], +) +``` + ### Canonical Cast Policy Implicit promotions in variable initializers, assignments, call arguments, and diff --git a/src/irx/analysis/contract.py b/src/irx/analysis/contract.py index 1bc08546..d297c1ff 100644 --- a/src/irx/analysis/contract.py +++ b/src/irx/analysis/contract.py @@ -162,6 +162,7 @@ class SemanticContract: "resolved_imports", "resolved_operator", "resolved_assignment", + "resolved_field_access", "semantic_flags", "extras", ), diff --git a/src/irx/analysis/handlers/base.py b/src/irx/analysis/handlers/base.py index c8634229..d0cab342 100644 --- a/src/irx/analysis/handlers/base.py +++ b/src/irx/analysis/handlers/base.py @@ -19,6 +19,7 @@ from irx.analysis.registry import SemanticRegistry from irx.analysis.resolved_nodes import ( ResolvedAssignment, + ResolvedFieldAccess, ResolvedImportBinding, ResolvedOperator, SemanticFlags, @@ -220,6 +221,86 @@ def _set_assignment( """ raise NotImplementedError + def _set_field_access( + self, + node: astx.AST, + field_access: ResolvedFieldAccess | None, + ) -> None: + """ + title: Attach resolved field access metadata. + parameters: + node: + type: astx.AST + field_access: + type: ResolvedFieldAccess | None + """ + raise NotImplementedError + + def _resolve_struct_from_type( + self, + type_: astx.DataType | None, + *, + node: astx.AST, + unknown_message: str, + ) -> SemanticStruct | None: + """ + title: Resolve one struct-valued type reference. + parameters: + type_: + type: astx.DataType | None + node: + type: astx.AST + unknown_message: + type: str + returns: + type: SemanticStruct | None + """ + raise NotImplementedError + + def _resolve_declared_type( + self, + type_: astx.DataType, + *, + node: astx.AST, + unknown_message: str = "Unknown type '{name}'", + ) -> astx.DataType: + """ + title: Resolve one declared type in place. + parameters: + type_: + type: astx.DataType + node: + type: astx.AST + unknown_message: + type: str + returns: + type: astx.DataType + """ + raise NotImplementedError + + def _root_assignment_symbol( + self, + node: astx.AST | None, + ) -> SemanticSymbol | None: + """ + title: Resolve the root symbol for an assignment target chain. + parameters: + node: + type: astx.AST | None + returns: + type: SemanticSymbol | None + """ + raise NotImplementedError + + def _predeclare_block_structs(self, block: astx.Block) -> None: + """ + title: Predeclare struct definitions in one block. + parameters: + block: + type: astx.Block + """ + raise NotImplementedError + def _current_module_key(self) -> ModuleKey: """ title: Return the current module key. @@ -548,6 +629,129 @@ def _set_assignment( return info.resolved_assignment = ResolvedAssignment(symbol) + def _set_field_access( + self, + node: astx.AST, + field_access: ResolvedFieldAccess | None, + ) -> None: + """ + title: Attach resolved field access metadata. + parameters: + node: + type: astx.AST + field_access: + type: ResolvedFieldAccess | None + """ + self._semantic(node).resolved_field_access = field_access + + def _resolve_struct_from_type( + self, + type_: astx.DataType | None, + *, + node: astx.AST, + unknown_message: str, + ) -> SemanticStruct | None: + """ + title: Resolve one struct-valued type reference. + parameters: + type_: + type: astx.DataType | None + node: + type: astx.AST + unknown_message: + type: str + returns: + type: SemanticStruct | None + """ + if not isinstance(type_, astx.StructType): + return None + + binding = self.bindings.resolve(type_.name) + struct = ( + binding.struct + if binding is not None and binding.kind == "struct" + else None + ) + if struct is None and type_.module_key is not None: + lookup_name = type_.resolved_name or type_.name + struct = self.context.get_struct(type_.module_key, lookup_name) + if struct is None: + self.context.diagnostics.add( + unknown_message.format(name=type_.name), + node=node, + ) + return None + + type_.resolved_name = struct.name + type_.module_key = struct.module_key + type_.qualified_name = struct.qualified_name + self._set_struct(type_, struct) + self._set_type(type_, type_) + return struct + + def _resolve_declared_type( + self, + type_: astx.DataType, + *, + node: astx.AST, + unknown_message: str = "Unknown type '{name}'", + ) -> astx.DataType: + """ + title: Resolve one declared type in place. + parameters: + type_: + type: astx.DataType + node: + type: astx.AST + unknown_message: + type: str + returns: + type: astx.DataType + """ + self._resolve_struct_from_type( + type_, + node=node, + unknown_message=unknown_message, + ) + return type_ + + def _root_assignment_symbol( + self, + node: astx.AST | None, + ) -> SemanticSymbol | None: + """ + title: Resolve the root symbol for an assignment target chain. + parameters: + node: + type: astx.AST | None + returns: + type: SemanticSymbol | None + """ + if node is None: + return None + if isinstance(node, astx.Identifier): + return cast( + SemanticInfo, + getattr(node, "semantic", SemanticInfo()), + ).resolved_symbol + if isinstance(node, astx.FieldAccess): + return self._root_assignment_symbol(node.value) + return None + + def _predeclare_block_structs(self, block: astx.Block) -> None: + """ + title: Predeclare struct definitions in one block. + parameters: + block: + type: astx.Block + """ + for node in block.nodes: + if not isinstance(node, astx.StructDefStmt): + continue + struct = self.registry.register_struct(node) + self.bindings.bind_struct(node.name, struct, node=node) + self._set_struct(node, struct) + def _current_module_key(self) -> ModuleKey: """ title: Return the current module key. diff --git a/src/irx/analysis/handlers/declarations.py b/src/irx/analysis/handlers/declarations.py index cf6b8b5f..4dfda340 100644 --- a/src/irx/analysis/handlers/declarations.py +++ b/src/irx/analysis/handlers/declarations.py @@ -9,17 +9,187 @@ from __future__ import annotations +from dataclasses import replace + from irx import astx from irx.analysis.handlers.base import ( SemanticAnalyzerCore, SemanticVisitorMixinBase, ) +from irx.analysis.resolved_nodes import ( + SemanticFunction, + SemanticStruct, + SemanticStructField, +) +from irx.analysis.types import clone_type from irx.analysis.validation import validate_assignment from irx.typecheck import typechecked +DIRECT_STRUCT_CYCLE_LENGTH = 2 + @typechecked class DeclarationVisitorMixin(SemanticVisitorMixinBase): + def _synchronize_function_signature( + self, + function: SemanticFunction, + prototype: astx.FunctionPrototype, + *, + definition: astx.FunctionDef | None = None, + ) -> SemanticFunction: + """ + title: Synchronize one semantic function with resolved AST types. + parameters: + function: + type: SemanticFunction + prototype: + type: astx.FunctionPrototype + definition: + type: astx.FunctionDef | None + returns: + type: SemanticFunction + """ + updated = replace( + function, + return_type=clone_type(prototype.return_type), + args=tuple( + replace(arg_symbol, type_=clone_type(arg_node.type_)) + for arg_node, arg_symbol in zip( + prototype.args.nodes, + function.args, + ) + ), + prototype=prototype, + definition=( + definition if definition is not None else function.definition + ), + ) + self.context.register_function(updated) + return updated + + def _resolve_struct_fields( + self, + struct: SemanticStruct, + ) -> SemanticStruct: + """ + title: Resolve one struct's ordered field metadata. + parameters: + struct: + type: SemanticStruct + returns: + type: SemanticStruct + """ + seen: set[str] = set() + fields: list[SemanticStructField] = [] + + if len(list(struct.declaration.attributes)) == 0: + self.context.diagnostics.add( + f"Struct '{struct.name}' must declare at least one field", + node=struct.declaration, + ) + + for index, attr in enumerate(struct.declaration.attributes): + if attr.name in seen: + self.context.diagnostics.add( + f"Struct field '{attr.name}' already defined.", + node=attr, + ) + seen.add(attr.name) + self._resolve_declared_type( + attr.type_, + node=attr, + unknown_message="Unknown field type '{name}'", + ) + fields.append( + SemanticStructField( + name=attr.name, + index=index, + type_=clone_type(attr.type_), + declaration=attr, + ) + ) + + updated = replace( + struct, + fields=tuple(fields), + field_indices={field.name: field.index for field in fields}, + ) + self.context.register_struct(updated) + self.bindings.bind_struct( + updated.name, + updated, + node=updated.declaration, + ) + self._set_struct(updated.declaration, updated) + return updated + + def _find_struct_cycle( + self, + root: SemanticStruct, + current: SemanticStruct, + path: tuple[SemanticStruct, ...], + ) -> tuple[SemanticStruct, ...] | None: + """ + title: Find one by-value recursive struct cycle. + parameters: + root: + type: SemanticStruct + current: + type: SemanticStruct + path: + type: tuple[SemanticStruct, Ellipsis] + returns: + type: tuple[SemanticStruct, Ellipsis] | None + """ + seen = {struct.qualified_name for struct in path} + for attr in current.declaration.attributes: + field_struct = self._resolve_struct_from_type( + attr.type_, + node=attr, + unknown_message="Unknown field type '{name}'", + ) + if field_struct is None: + continue + if field_struct.qualified_name == root.qualified_name: + return (*path, field_struct) + if field_struct.qualified_name in seen: + continue + cycle = self._find_struct_cycle( + root, + field_struct, + (*path, field_struct), + ) + if cycle is not None: + return cycle + return None + + def _validate_struct_cycles(self, struct: SemanticStruct) -> None: + """ + title: Reject by-value recursive struct layouts. + parameters: + struct: + type: SemanticStruct + """ + cycle = self._find_struct_cycle(struct, struct, (struct,)) + if cycle is None: + return + + if len(cycle) == DIRECT_STRUCT_CYCLE_LENGTH: + self.context.diagnostics.add( + ( + "direct by-value recursive struct " + f"'{struct.name}' is forbidden" + ), + node=struct.declaration, + ) + return + + cycle_names = " -> ".join(item.name for item in cycle) + self.context.diagnostics.add( + f"mutual by-value recursive structs are forbidden: {cycle_names}", + node=struct.declaration, + ) + @SemanticAnalyzerCore.visit.dispatch def visit(self, module: astx.Module) -> None: """ @@ -40,6 +210,7 @@ def visit(self, block: astx.Block) -> None: type: astx.Block """ self._set_type(block, None) + self._predeclare_block_structs(block) for node in block.nodes: self.visit(node) @@ -51,9 +222,13 @@ def visit(self, node: astx.FunctionPrototype) -> None: node: type: astx.FunctionPrototype """ + for arg in node.args.nodes: + self._resolve_declared_type(arg.type_, node=arg) + self._resolve_declared_type(node.return_type, node=node) function = self.registry.resolve_function(node.name) if function is None: function = self.registry.register_function(node) + function = self._synchronize_function_signature(function, node) self.bindings.bind_function(node.name, function, node=node) self._set_function(node, function) @@ -65,12 +240,20 @@ def visit(self, node: astx.FunctionDef) -> None: node: type: astx.FunctionDef """ + for arg in node.prototype.args.nodes: + self._resolve_declared_type(arg.type_, node=arg) + self._resolve_declared_type(node.prototype.return_type, node=node) function = self.registry.resolve_function(node.name) if function is None: function = self.registry.register_function( node.prototype, definition=node, ) + function = self._synchronize_function_signature( + function, + node.prototype, + definition=node, + ) self.bindings.bind_function(node.name, function, node=node) self._set_function(node.prototype, function) self._set_function(node, function) @@ -101,6 +284,7 @@ def visit(self, node: astx.VariableDeclaration) -> None: node: type: astx.VariableDeclaration """ + self._resolve_declared_type(node.type_, node=node) if node.value is not None and not isinstance( node.value, astx.Undefined ): @@ -128,6 +312,7 @@ def visit(self, node: astx.InlineVariableDeclaration) -> None: node: type: astx.InlineVariableDeclaration """ + self._resolve_declared_type(node.type_, node=node) if node.value is not None and not isinstance( node.value, astx.Undefined ): @@ -157,13 +342,7 @@ def visit(self, node: astx.StructDefStmt) -> None: """ struct = self.registry.register_struct(node) self.bindings.bind_struct(node.name, struct, node=node) + struct = self._resolve_struct_fields(struct) self._set_struct(node, struct) - seen: set[str] = set() - for attr in node.attributes: - if attr.name in seen: - self.context.diagnostics.add( - f"Struct field '{attr.name}' already defined.", - node=attr, - ) - seen.add(attr.name) + self._validate_struct_cycles(struct) self._set_type(node, None) diff --git a/src/irx/analysis/handlers/expressions.py b/src/irx/analysis/handlers/expressions.py index acef8451..8472d6c0 100644 --- a/src/irx/analysis/handlers/expressions.py +++ b/src/irx/analysis/handlers/expressions.py @@ -17,7 +17,7 @@ SemanticVisitorMixinBase, ) from irx.analysis.normalization import normalize_flags, normalize_operator -from irx.analysis.resolved_nodes import SemanticInfo +from irx.analysis.resolved_nodes import ResolvedFieldAccess, SemanticInfo from irx.analysis.types import ( is_boolean_type, is_float_type, @@ -62,6 +62,7 @@ def visit(self, node: astx.Identifier) -> None: ) return self._set_symbol(node, symbol) + self._set_type(node, symbol.type_) @SemanticAnalyzerCore.visit.dispatch def visit(self, node: astx.VariableAssignment) -> None: @@ -160,41 +161,47 @@ def visit(self, node: astx.BinaryOp) -> None: self._semantic(node).extras[SPECIALIZED_BINARY_OP_EXTRA] = specialized if node.op_code == "=": - if not isinstance(node.lhs, astx.Identifier): + if not isinstance(node.lhs, (astx.Identifier, astx.FieldAccess)): self.context.diagnostics.add( - "destination of '=' must be a variable", + "destination of '=' must be a variable or field", node=node, ) return - symbol = self.context.scopes.resolve(node.lhs.name) + symbol = self._root_assignment_symbol(node.lhs) if symbol is None: self.context.diagnostics.add( - "codegen: Invalid lhs variable name", + "destination of '=' must be a variable or field", node=node, ) return if not symbol.is_mutable: self.context.diagnostics.add( - "Cannot assign to " - f"'{node.lhs.name}': declared as constant", + f"Cannot assign to '{symbol.name}': declared as constant", node=node, ) + target_name = ( + node.lhs.name + if isinstance(node.lhs, astx.Identifier) + else node.lhs.field_name + ) + target_type = self._expr_type(node.lhs) validate_assignment( self.context.diagnostics, - target_name=node.lhs.name, - target_type=symbol.type_, + target_name=target_name, + target_type=target_type, value_type=rhs_type, node=node, ) self._set_assignment(node, symbol) - self._set_symbol(node.lhs, symbol) - self._set_type(node, symbol.type_) + if isinstance(node.lhs, astx.Identifier): + self._set_symbol(node.lhs, symbol) + self._set_type(node, target_type) self._set_operator( node, normalize_operator( node.op_code, - result_type=symbol.type_, - lhs_type=symbol.type_, + result_type=target_type, + lhs_type=target_type, rhs_type=rhs_type, flags=flags, ), @@ -274,6 +281,7 @@ def visit(self, node: astx.FunctionCall) -> None: return function = binding.function self._set_function(node, function) + self._set_type(node, function.return_type) validate_call( self.context.diagnostics, function=function, @@ -281,6 +289,44 @@ def visit(self, node: astx.FunctionCall) -> None: node=node, ) + @SemanticAnalyzerCore.visit.dispatch + def visit(self, node: astx.FieldAccess) -> None: + """ + title: Visit FieldAccess nodes. + parameters: + node: + type: astx.FieldAccess + """ + self.visit(node.value) + base_type = self._expr_type(node.value) + struct = self._resolve_struct_from_type( + base_type, + node=node, + unknown_message="field access requires a struct value", + ) + if struct is None: + if not isinstance(base_type, astx.StructType): + self.context.diagnostics.add( + "field access requires a struct value", + node=node, + ) + self._set_type(node, None) + return + + field_index = struct.field_indices.get(node.field_name) + if field_index is None or field_index >= len(struct.fields): + self.context.diagnostics.add( + f"struct '{struct.name}' has no field '{node.field_name}'", + node=node, + ) + self._set_type(node, None) + return + + field = struct.fields[field_index] + self._set_struct(node, struct) + self._set_field_access(node, ResolvedFieldAccess(struct, field)) + self._set_type(node, field.type_) + @SemanticAnalyzerCore.visit.dispatch def visit(self, node: astx.Cast) -> None: """ diff --git a/src/irx/analysis/resolved_nodes.py b/src/irx/analysis/resolved_nodes.py index da8fe9de..19b2dc82 100644 --- a/src/irx/analysis/resolved_nodes.py +++ b/src/irx/analysis/resolved_nodes.py @@ -74,6 +74,10 @@ class SemanticStruct: type: str declaration: type: astx.StructDefStmt + fields: + type: tuple[SemanticStructField, Ellipsis] + field_indices: + type: dict[str, int] """ symbol_id: str @@ -81,6 +85,34 @@ class SemanticStruct: module_key: ModuleKey qualified_name: str declaration: astx.StructDefStmt + fields: tuple["SemanticStructField", ...] = () + field_indices: dict[str, int] = field(default_factory=dict) + + +@public +@typechecked +@dataclass(frozen=True) +class SemanticStructField: + """ + title: Resolved struct field information. + summary: >- + Describe one ordered field within a semantic struct, including its stable + index and resolved field type. + attributes: + name: + type: str + index: + type: int + type_: + type: astx.DataType + declaration: + type: astx.VariableDeclaration + """ + + name: str + index: int + type_: astx.DataType + declaration: astx.VariableDeclaration @public @@ -271,6 +303,26 @@ class ResolvedAssignment: target: SemanticSymbol +@public +@typechecked +@dataclass(frozen=True) +class ResolvedFieldAccess: + """ + title: Resolved field access metadata. + summary: >- + Point from a field-access node to its owning struct and stable field + metadata. + attributes: + struct: + type: SemanticStruct + field: + type: SemanticStructField + """ + + struct: SemanticStruct + field: SemanticStructField + + @public @typechecked @dataclass @@ -297,6 +349,8 @@ class SemanticInfo: type: ResolvedOperator | None resolved_assignment: type: ResolvedAssignment | None + resolved_field_access: + type: ResolvedFieldAccess | None semantic_flags: type: SemanticFlags extras: @@ -311,5 +365,6 @@ class SemanticInfo: resolved_imports: tuple[ResolvedImportBinding, ...] = () resolved_operator: ResolvedOperator | None = None resolved_assignment: ResolvedAssignment | None = None + resolved_field_access: ResolvedFieldAccess | None = None semantic_flags: SemanticFlags = field(default_factory=SemanticFlags) extras: dict[str, Any] = field(default_factory=dict) diff --git a/src/irx/analysis/types.py b/src/irx/analysis/types.py index 479618b7..ccda3265 100644 --- a/src/irx/analysis/types.py +++ b/src/irx/analysis/types.py @@ -68,6 +68,13 @@ def clone_type(type_: astx.DataType) -> astx.DataType: returns: type: astx.DataType """ + if isinstance(type_, astx.StructType): + return astx.StructType( + type_.name, + resolved_name=type_.resolved_name, + module_key=type_.module_key, + qualified_name=type_.qualified_name, + ) return type_.__class__() @@ -86,6 +93,10 @@ def same_type(lhs: astx.DataType | None, rhs: astx.DataType | None) -> bool: """ if lhs is None or rhs is None: return False + if isinstance(lhs, astx.StructType) and isinstance(rhs, astx.StructType): + lhs_identity = lhs.qualified_name or lhs.name + rhs_identity = rhs.qualified_name or rhs.name + return lhs_identity == rhs_identity return lhs.__class__ is rhs.__class__ diff --git a/src/irx/astx/__init__.py b/src/irx/astx/__init__.py index da735185..31b9040a 100644 --- a/src/irx/astx/__init__.py +++ b/src/irx/astx/__init__.py @@ -69,6 +69,8 @@ from irx.astx.binary_op import ( specialize_binary_op as specialize_binary_op, ) +from irx.astx.structs import FieldAccess as FieldAccess +from irx.astx.structs import StructType as StructType from irx.astx.system import Cast as Cast from irx.astx.system import PrintExpr as PrintExpr from irx.typecheck import typechecked @@ -84,6 +86,7 @@ "Cast", "DivBinOp", "EqBinOp", + "FieldAccess", "GeBinOp", "GtBinOp", "LeBinOp", @@ -94,6 +97,7 @@ "MulBinOp", "NeBinOp", "PrintExpr", + "StructType", "SubBinOp", "binary_op_type_for_opcode", "specialize_binary_op", diff --git a/src/irx/astx/structs.py b/src/irx/astx/structs.py new file mode 100644 index 00000000..b8c9b1c4 --- /dev/null +++ b/src/irx/astx/structs.py @@ -0,0 +1,135 @@ +""" +title: IRX-owned struct AST nodes. +""" + +from __future__ import annotations + +import astx + +from astx.types import AnyType + +from irx.typecheck import typechecked + + +@typechecked +class StructType(AnyType): + """ + title: Named struct type reference. + attributes: + name: + type: str + resolved_name: + type: str | None + module_key: + type: str | None + qualified_name: + type: str | None + """ + + name: str + resolved_name: str | None + module_key: str | None + qualified_name: str | None + + def __init__( + self, + name: str, + *, + resolved_name: str | None = None, + module_key: str | None = None, + qualified_name: str | None = None, + ) -> None: + """ + title: Initialize one named struct type reference. + parameters: + name: + type: str + resolved_name: + type: str | None + module_key: + type: str | None + qualified_name: + type: str | None + """ + super().__init__() + self.name = name + self.resolved_name = resolved_name + self.module_key = module_key + self.qualified_name = qualified_name + + def __str__(self) -> str: + """ + title: Render one struct type reference as text. + returns: + type: str + """ + return f"StructType[{self.name}]" + + def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: + """ + title: Build one repr structure for a struct type reference. + parameters: + simplified: + type: bool + returns: + type: astx.base.ReprStruct + """ + key = f"STRUCT-TYPE[{self.name}]" + value = self.qualified_name or self.name + return self._prepare_struct(key, value, simplified) + + +@typechecked +class FieldAccess(astx.DataType): + """ + title: Field access expression. + attributes: + value: + type: astx.AST + field_name: + type: str + type_: + type: AnyType + """ + + value: astx.AST + field_name: str + type_: AnyType + + def __init__(self, value: astx.AST, field_name: str) -> None: + """ + title: Initialize one field access expression. + parameters: + value: + type: astx.AST + field_name: + type: str + """ + super().__init__() + self.value = value + self.field_name = field_name + self.type_ = AnyType() + + def __str__(self) -> str: + """ + title: Render one field access expression as text. + returns: + type: str + """ + return f"FieldAccess[{self.field_name}]" + + def get_struct(self, simplified: bool = False) -> astx.base.ReprStruct: + """ + title: Build one repr structure for a field access expression. + parameters: + simplified: + type: bool + returns: + type: astx.base.ReprStruct + """ + key = f"FIELD-ACCESS[{self.field_name}]" + value = self.value.get_struct(simplified) + return self._prepare_struct(key, value, simplified) + + +__all__ = ["FieldAccess", "StructType"] diff --git a/src/irx/builder/core.py b/src/irx/builder/core.py index 3f616da8..e35109b4 100644 --- a/src/irx/builder/core.py +++ b/src/irx/builder/core.py @@ -26,6 +26,7 @@ from irx.analysis.module_symbols import ( mangle_function_name, mangle_struct_name, + qualified_struct_name, ) from irx.analysis.types import ( bit_width, @@ -632,7 +633,7 @@ def get_function(self, name: str) -> ir.Function | None: def create_entry_block_alloca( self, var_name: str, - type_name: str, + type_name: str | ir.Type, ) -> Any: """ title: Create entry block alloca. @@ -640,17 +641,20 @@ def create_entry_block_alloca( var_name: type: str type_name: - type: str + type: str | ir.Type returns: type: Any """ + llvm_type = ( + self._llvm.get_data_type(type_name) + if isinstance(type_name, str) + else type_name + ) current_block = self._llvm.ir_builder.block self._llvm.ir_builder.position_at_start( self._llvm.ir_builder.function.entry_basic_block ) - alloca = self._llvm.ir_builder.alloca( - self._llvm.get_data_type(type_name), None, var_name - ) + alloca = self._llvm.ir_builder.alloca(llvm_type, None, var_name) if current_block is not None: self._llvm.ir_builder.position_at_end(current_block) return alloca @@ -802,9 +806,92 @@ def _llvm_type_for_ast_type( """ if type_ is None: return None + if isinstance(type_, astx.StructType): + struct_key = type_.qualified_name + if struct_key is None and type_.module_key is not None: + struct_key = qualified_struct_name( + type_.module_key, + type_.resolved_name or type_.name, + ) + if struct_key is None: + return None + return self.struct_types.get(struct_key) type_name = type_.__class__.__name__.lower() return self._llvm.get_data_type(type_name) + def _field_address(self, node: astx.FieldAccess) -> ir.Value: + """ + title: Lower one field-access expression to an address. + parameters: + node: + type: astx.FieldAccess + returns: + type: ir.Value + """ + semantic = getattr(node, "semantic", None) + resolved_field_access = getattr( + semantic, + "resolved_field_access", + None, + ) + if resolved_field_access is None: + raise Exception("codegen: unresolved field access.") + + if isinstance(node.value, astx.Identifier): + base_key = semantic_symbol_key(node.value, node.value.name) + base_ptr = self.named_values.get(base_key) + if base_ptr is None: + raise Exception(f"Unknown variable name: {node.value.name}") + elif isinstance(node.value, astx.FieldAccess): + base_ptr = self._field_address(node.value) + else: + self.visit(node.value) + base_value = safe_pop(self.result_stack) + if base_value is None: + raise Exception("codegen: invalid field access base.") + base_type = self._resolved_ast_type(node.value) + llvm_base_type = self._llvm_type_for_ast_type(base_type) + if llvm_base_type is None: + llvm_base_type = base_value.type + temp_name = f"fieldtmp_{id(node.value)}" + base_ptr = self.create_entry_block_alloca( + temp_name, + llvm_base_type, + ) + self._llvm.ir_builder.store(base_value, base_ptr) + + indices = [ + ir.Constant(self._llvm.INT32_TYPE, 0), + ir.Constant( + self._llvm.INT32_TYPE, + resolved_field_access.field.index, + ), + ] + source_etype = self._llvm_type_for_ast_type( + self._resolved_ast_type(node.value) + ) + if not isinstance(node.value, astx.FieldAccess): + return self._llvm.ir_builder.gep( + base_ptr, + indices, + inbounds=True, + name=f"{resolved_field_access.field.name}_addr", + ) + if source_etype is not None: + typed_ptr = source_etype.as_pointer() + if base_ptr.type != typed_ptr: + base_ptr = self._llvm.ir_builder.bitcast( + base_ptr, + typed_ptr, + name=f"{resolved_field_access.field.name}_baseptr", + ) + return self._llvm.ir_builder.gep( + base_ptr, + indices, + inbounds=True, + name=f"{resolved_field_access.field.name}_addr", + ) + def _bool_value_from_numeric( self, value: ir.Value, diff --git a/src/irx/builder/lowering/binary_ops.py b/src/irx/builder/lowering/binary_ops.py index 2bfb9725..7258ef8d 100644 --- a/src/irx/builder/lowering/binary_ops.py +++ b/src/irx/builder/lowering/binary_ops.py @@ -33,9 +33,9 @@ ) from irx.builder.core import ( VisitorCore, + semantic_assignment_key, semantic_flag, semantic_fma_rhs, - semantic_symbol_key, uses_unsigned_semantics, ) from irx.builder.protocols import VisitorMixinBase @@ -392,11 +392,15 @@ def visit(self, node: AssignmentBinOp) -> None: type: AssignmentBinOp """ var_lhs = node.lhs - if not isinstance(var_lhs, astx.Identifier): - raise Exception("destination of '=' must be a variable") + if not isinstance(var_lhs, (astx.Identifier, astx.FieldAccess)): + raise Exception("destination of '=' must be a variable or field") - lhs_name = var_lhs.name - lhs_key = semantic_symbol_key(var_lhs, lhs_name) + lhs_name = ( + var_lhs.name + if isinstance(var_lhs, astx.Identifier) + else var_lhs.field_name + ) + lhs_key = semantic_assignment_key(node, lhs_name) if lhs_key in self.const_vars: raise Exception( f"Cannot assign to '{lhs_name}': declared as constant" @@ -412,9 +416,34 @@ def visit(self, node: AssignmentBinOp) -> None: target_type=self._resolved_ast_type(node), ) - llvm_lhs = self.named_values.get(lhs_key) - if not llvm_lhs: - raise Exception("codegen: Invalid lhs variable name") + if isinstance(var_lhs, astx.Identifier): + llvm_lhs = self.named_values.get(lhs_key) + if not llvm_lhs: + raise Exception("codegen: Invalid lhs variable name") + else: + if isinstance(var_lhs.value, astx.FieldAccess): + parent_ptr = self._field_address(var_lhs.value) + parent_value = self._llvm.ir_builder.load( + parent_ptr, + f"{var_lhs.field_name}_parent", + ) + resolved = getattr( + getattr(var_lhs, "semantic", None), + "resolved_field_access", + None, + ) + if resolved is None: + raise Exception("codegen: unresolved field access.") + updated_parent = self._llvm.ir_builder.insert_value( + parent_value, + llvm_rhs, + resolved.field.index, + name=f"set_{var_lhs.field_name}", + ) + self._llvm.ir_builder.store(updated_parent, parent_ptr) + self.result_stack.append(llvm_rhs) + return + llvm_lhs = self._field_address(var_lhs) self._llvm.ir_builder.store(llvm_rhs, llvm_lhs) self.result_stack.append(llvm_rhs) diff --git a/src/irx/builder/lowering/control_flow.py b/src/irx/builder/lowering/control_flow.py index e9fb0795..0eb99c1f 100644 --- a/src/irx/builder/lowering/control_flow.py +++ b/src/irx/builder/lowering/control_flow.py @@ -49,6 +49,10 @@ def visit(self, block: astx.Block) -> None: block: type: astx.Block """ + for node in block.nodes: + if isinstance(node, astx.StructDefStmt): + self.visit_child(node) + result = None for node in block.nodes: if self._llvm.ir_builder.block.terminator is not None: diff --git a/src/irx/builder/lowering/functions.py b/src/irx/builder/lowering/functions.py index 64b465f3..d8fbd640 100644 --- a/src/irx/builder/lowering/functions.py +++ b/src/irx/builder/lowering/functions.py @@ -87,11 +87,16 @@ def visit(self, node: astx.FunctionDef) -> None: try: for idx, llvm_arg in enumerate(fn.args): arg_ast = proto.args.nodes[idx] - type_str = arg_ast.type_.__class__.__name__.lower() - arg_type = self._llvm.get_data_type(type_str) symbol_key = semantic_symbol_key(arg_ast, llvm_arg.name) + arg_type = self._llvm_type_for_ast_type(arg_ast.type_) + if arg_type is None: + raise Exception( + "codegen: Unknown LLVM type for function argument " + f"'{llvm_arg.name}'." + ) alloca = self._llvm.ir_builder.alloca( - arg_type, name=llvm_arg.name + arg_type, + name=llvm_arg.name, ) self._llvm.ir_builder.store(llvm_arg, alloca) self.named_values[symbol_key] = alloca @@ -122,12 +127,20 @@ def visit(self, node: astx.FunctionPrototype) -> None: """ args_type = [] for arg in node.args.nodes: - type_str = arg.type_.__class__.__name__.lower() - args_type.append(self._llvm.get_data_type(type_str)) + llvm_type = self._llvm_type_for_ast_type(arg.type_) + if llvm_type is None: + raise Exception( + "codegen: Unknown LLVM type for function argument " + f"'{arg.name}'." + ) + args_type.append(llvm_type) - return_type = self._llvm.get_data_type( - node.return_type.__class__.__name__.lower() - ) + return_type = self._llvm_type_for_ast_type(node.return_type) + if return_type is None: + raise Exception( + "codegen: Unknown LLVM return type for function " + f"'{node.name}'." + ) fn_type = ir.FunctionType(return_type, args_type, False) function_key = semantic_function_key(node, node.name) existing = self.llvm_functions_by_symbol_id.get(function_key) diff --git a/src/irx/builder/lowering/modules.py b/src/irx/builder/lowering/modules.py index b2a97464..762566f0 100644 --- a/src/irx/builder/lowering/modules.py +++ b/src/irx/builder/lowering/modules.py @@ -42,15 +42,29 @@ def visit(self, node: astx.StructDefStmt) -> None: self.struct_types[struct_key] = existing return + field_types: list[ir.Type] = [] + semantic = getattr(node, "semantic", None) + resolved_struct = getattr(semantic, "resolved_struct", None) + fields = ( + resolved_struct.fields + if resolved_struct is not None and resolved_struct.fields + else () + ) + for field in fields: + llvm_type = self._llvm_type_for_ast_type(field.type_) + if llvm_type is None: + raise Exception( + f"codegen: Unknown LLVM type for struct field " + f"'{field.name}'." + ) + field_types.append(llvm_type) + llvm_name = semantic_struct_name(node, node.name) struct_type = self._llvm.module.context.get_identified_type(llvm_name) if not struct_type.is_opaque: - raise ValueError(f"Struct '{node.name}' already defined.") - - field_types: list[ir.Type] = [] - for attr in node.attributes: - type_str = attr.type_.__class__.__name__.lower() - field_types.append(self._llvm.get_data_type(type_str)) + self.struct_types[struct_key] = struct_type + self.llvm_structs_by_qualified_name[struct_key] = struct_type + return struct_type.set_body(*field_types) self.struct_types[struct_key] = struct_type diff --git a/src/irx/builder/lowering/variables.py b/src/irx/builder/lowering/variables.py index 2d63488f..a1fc97e2 100644 --- a/src/irx/builder/lowering/variables.py +++ b/src/irx/builder/lowering/variables.py @@ -70,6 +70,39 @@ def visit(self, node: astx.Identifier) -> None: result = self._llvm.ir_builder.load(expr_var, node.name) self.result_stack.append(result) + @VisitorCore.visit.dispatch + def visit(self, node: astx.FieldAccess) -> None: + """ + title: Visit FieldAccess nodes. + parameters: + node: + type: astx.FieldAccess + """ + if isinstance(node.value, astx.FieldAccess): + parent_ptr = self._field_address(node.value) + parent_value = self._llvm.ir_builder.load( + parent_ptr, + f"{node.field_name}_parent", + ) + resolved = getattr( + getattr(node, "semantic", None), + "resolved_field_access", + None, + ) + if resolved is None: + raise Exception("codegen: unresolved field access.") + result = self._llvm.ir_builder.extract_value( + parent_value, + resolved.field.index, + node.field_name, + ) + self.result_stack.append(result) + return + + field_ptr = self._field_address(node) + result = self._llvm.ir_builder.load(field_ptr, node.field_name) + self.result_stack.append(result) + @VisitorCore.visit.dispatch def visit(self, node: astx.VariableDeclaration) -> None: """ @@ -83,6 +116,11 @@ def visit(self, node: astx.VariableDeclaration) -> None: raise Exception(f"Identifier already declared: {node.name}") type_str = node.type_.__class__.__name__.lower() + llvm_type = self._llvm_type_for_ast_type(node.type_) + if llvm_type is None: + raise Exception( + f"codegen: Unknown LLVM type for variable '{node.name}'." + ) if node.value is not None and not isinstance( node.value, astx.Undefined ): @@ -101,7 +139,7 @@ def visit(self, node: astx.VariableDeclaration) -> None: node.name, "stringascii" ) else: - alloca = self.create_entry_block_alloca(node.name, type_str) + alloca = self.create_entry_block_alloca(node.name, llvm_type) self._llvm.ir_builder.store(init_val, alloca) else: if type_str == "string": @@ -127,12 +165,15 @@ def visit(self, node: astx.VariableDeclaration) -> None: alloca = self.create_entry_block_alloca( node.name, "stringascii" ) + elif isinstance(node.type_, astx.StructType): + init_val = ir.Constant(llvm_type, None) + alloca = self.create_entry_block_alloca(node.name, llvm_type) elif "float" in type_str: init_val = ir.Constant(self._llvm.get_data_type(type_str), 0.0) - alloca = self.create_entry_block_alloca(node.name, type_str) + alloca = self.create_entry_block_alloca(node.name, llvm_type) else: init_val = ir.Constant(self._llvm.get_data_type(type_str), 0) - alloca = self.create_entry_block_alloca(node.name, type_str) + alloca = self.create_entry_block_alloca(node.name, llvm_type) self._llvm.ir_builder.store(init_val, alloca) @@ -153,6 +194,12 @@ def visit(self, node: astx.InlineVariableDeclaration) -> None: raise Exception(f"Identifier already declared: {node.name}") type_str = node.type_.__class__.__name__.lower() + llvm_type = self._llvm_type_for_ast_type(node.type_) + if llvm_type is None: + raise Exception( + "codegen: Unknown LLVM type for inline variable " + f"'{node.name}'." + ) if node.value is not None: self.visit_child(node.value) init_val = safe_pop(self.result_stack) @@ -163,6 +210,8 @@ def visit(self, node: astx.InlineVariableDeclaration) -> None: source_type=self._resolved_ast_type(node.value), target_type=node.type_, ) + elif isinstance(node.type_, astx.StructType): + init_val = ir.Constant(llvm_type, None) elif "float" in type_str: init_val = ir.Constant(self._llvm.get_data_type(type_str), 0.0) else: @@ -171,7 +220,7 @@ def visit(self, node: astx.InlineVariableDeclaration) -> None: if type_str == "string": alloca = self.create_entry_block_alloca(node.name, "stringascii") else: - alloca = self.create_entry_block_alloca(node.name, type_str) + alloca = self.create_entry_block_alloca(node.name, llvm_type) self._llvm.ir_builder.store(init_val, alloca) if node.mutability == astx.MutabilityKind.constant: diff --git a/src/irx/builder/protocols.py b/src/irx/builder/protocols.py index 0a9bd9aa..f02e2b86 100644 --- a/src/irx/builder/protocols.py +++ b/src/irx/builder/protocols.py @@ -93,7 +93,7 @@ def llvm_function_name_for_node( ... def create_entry_block_alloca( - self, _var_name: str, _type_name: str + self, _var_name: str, _type_name: str | ir.Type ) -> Any: """ title: Create entry block alloca. @@ -101,12 +101,23 @@ def create_entry_block_alloca( _var_name: type: str _type_name: - type: str + type: str | ir.Type returns: type: Any """ ... + def _field_address(self, _node: astx.FieldAccess) -> ir.Value: + """ + title: Lower one field access to an address. + parameters: + _node: + type: astx.FieldAccess + returns: + type: ir.Value + """ + ... + def require_runtime_symbol( self, _feature_name: str, _symbol_name: str ) -> ir.Function: @@ -472,7 +483,7 @@ def llvm_function_name_for_node( return _fallback def create_entry_block_alloca( - self, _var_name: str, _type_name: str + self, _var_name: str, _type_name: str | ir.Type ) -> Any: """ title: Create entry block alloca. @@ -480,12 +491,23 @@ def create_entry_block_alloca( _var_name: type: str _type_name: - type: str + type: str | ir.Type returns: type: Any """ return cast(Any, None) + def _field_address(self, _node: astx.FieldAccess) -> ir.Value: + """ + title: Lower one field access to an address. + parameters: + _node: + type: astx.FieldAccess + returns: + type: ir.Value + """ + return cast(ir.Value, None) + def require_runtime_symbol( self, _feature_name: str, _symbol_name: str ) -> ir.Function: diff --git a/tests/analysis/test_api.py b/tests/analysis/test_api.py index 4a4cc70d..2eff6e53 100644 --- a/tests/analysis/test_api.py +++ b/tests/analysis/test_api.py @@ -589,6 +589,7 @@ def test_public_analysis_contract_is_stable() -> None: "resolved_imports", "resolved_operator", "resolved_assignment", + "resolved_field_access", "semantic_flags", "extras", ) diff --git a/tests/test_struct_definition.py b/tests/test_struct_definition.py index 14c30b94..5401d05c 100644 --- a/tests/test_struct_definition.py +++ b/tests/test_struct_definition.py @@ -1,8 +1,9 @@ """ -title: Test Struct Definition -summary: Verify StructDefStmt generates an LLVM identified struct type. +title: Stable struct semantics and lowering tests. """ +from __future__ import annotations + import pytest from irx import astx @@ -11,22 +12,67 @@ from irx.builder import Builder as LLVMBuilder from irx.builder.base import Builder +from tests.conftest import assert_ir_parses, make_module -@pytest.mark.parametrize("builder_class", [LLVMBuilder]) -def test_struct_definition(builder_class: type[Builder]) -> None: + +def _struct_type(name: str) -> astx.StructType: """ - title: Struct definition code generation - summary: Ensure StructDefStmt translates to an LLVM struct type. + title: Build one named struct type reference. parameters: - builder_class: - type: type[Builder] + name: + type: str + returns: + type: astx.StructType """ + return astx.StructType(name) - builder = builder_class() - module = builder.module() - # Define struct: Point { x: int32, y: int32 } - struct_def = astx.StructDefStmt( +def _field(value: astx.AST, name: str) -> astx.FieldAccess: + """ + title: Build one field access node. + parameters: + value: + type: astx.AST + name: + type: str + returns: + type: astx.FieldAccess + """ + return astx.FieldAccess(value, name) + + +def _mutable_var( + name: str, + type_: astx.DataType, + value: astx.DataType | astx.Undefined = astx.Undefined(), +) -> astx.VariableDeclaration: + """ + title: Build one mutable local variable declaration. + parameters: + name: + type: str + type_: + type: astx.DataType + value: + type: astx.DataType | astx.Undefined + returns: + type: astx.VariableDeclaration + """ + return astx.VariableDeclaration( + name=name, + type_=type_, + mutability=astx.MutabilityKind.mutable, + value=value, + ) + + +def _point_struct() -> astx.StructDefStmt: + """ + title: Build a simple point struct. + returns: + type: astx.StructDefStmt + """ + return astx.StructDefStmt( name="Point", attributes=[ astx.VariableDeclaration(name="x", type_=astx.Int32()), @@ -34,119 +80,582 @@ def test_struct_definition(builder_class: type[Builder]) -> None: ], ) - # Define main() -> int32 - main_proto = astx.FunctionPrototype( - name="main", - args=astx.Arguments(), - return_type=astx.Int32(), - ) - - main_block = astx.Block() - main_block.append(struct_def) - main_block.append(astx.FunctionReturn(astx.LiteralInt32(0))) - main_fn = astx.FunctionDef( - prototype=main_proto, - body=main_block, +def _main_int32(*body_nodes: astx.AST) -> astx.FunctionDef: + """ + title: Build a small int32-returning main function. + parameters: + body_nodes: + type: astx.AST + variadic: positional + returns: + type: astx.FunctionDef + """ + body = astx.Block() + for node in body_nodes: + body.append(node) + if not any(isinstance(node, astx.FunctionReturn) for node in body_nodes): + body.append(astx.FunctionReturn(astx.LiteralInt32(0))) + return astx.FunctionDef( + prototype=astx.FunctionPrototype( + name="main", + args=astx.Arguments(), + return_type=astx.Int32(), + ), + body=body, ) - module.block.append(main_fn) + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_struct_definition_simple(builder_class: type[Builder]) -> None: + """ + title: Simple struct definitions lower to identified LLVM structs. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + module = make_module("main", _point_struct(), _main_int32()) ir_text = builder.translate(module) point_name = mangle_struct_name("main", "Point") + assert f'%"{point_name}" = type {{i32, i32}}' in ir_text + assert_ir_parses(ir_text) @pytest.mark.parametrize("builder_class", [LLVMBuilder]) -def test_struct_definition_single_field(builder_class: type[Builder]) -> None: +def test_struct_field_layout_stays_in_declaration_order( + builder_class: type[Builder], +) -> None: """ - title: Single field struct definition - summary: Ensure StructDefStmt works with a single attribute. + title: Field order stays stable in the emitted LLVM type. parameters: builder_class: type: type[Builder] """ - builder = builder_class() - module = builder.module() + record = astx.StructDefStmt( + name="Record", + attributes=[ + astx.VariableDeclaration(name="flag", type_=astx.Boolean()), + astx.VariableDeclaration(name="count", type_=astx.Int32()), + astx.VariableDeclaration(name="weight", type_=astx.Float64()), + ], + ) + module = make_module("main", record, _main_int32()) + + ir_text = builder.translate(module) + + assert '%"main__Record" = type {i1, i32, double}' in ir_text + assert_ir_parses(ir_text) - struct_def = astx.StructDefStmt( - name="Value", + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_nested_struct_layout_is_stable(builder_class: type[Builder]) -> None: + """ + title: Nested structs preserve declaration-order layout transitively. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + inner = astx.StructDefStmt( + name="Inner", attributes=[ - astx.VariableDeclaration(name="x", type_=astx.Int32()), + astx.VariableDeclaration(name="tag", type_=astx.UInt8()), + astx.VariableDeclaration(name="value", type_=astx.Int32()), + ], + ) + outer = astx.StructDefStmt( + name="Outer", + attributes=[ + astx.VariableDeclaration( + name="inner", + type_=_struct_type("Inner"), + ), + astx.VariableDeclaration(name="ready", type_=astx.Boolean()), ], ) + module = make_module("main", inner, outer, _main_int32()) - main_proto = astx.FunctionPrototype( - name="main", - args=astx.Arguments(), - return_type=astx.Int32(), + ir_text = builder.translate(module) + + assert '%"main__Inner" = type {i8, i32}' in ir_text + assert '%"main__Outer" = type {%"main__Inner", i1}' in ir_text + assert_ir_parses(ir_text) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_struct_field_reads_and_writes_use_stable_indices( + builder_class: type[Builder], +) -> None: + """ + title: Field access uses declaration-order indices for reads and writes. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + triple = astx.StructDefStmt( + name="Triple", + attributes=[ + astx.VariableDeclaration(name="first", type_=astx.Int32()), + astx.VariableDeclaration(name="middle", type_=astx.Int32()), + astx.VariableDeclaration(name="last", type_=astx.Int32()), + ], ) + main_fn = _main_int32( + _mutable_var("t", _struct_type("Triple")), + astx.BinaryOp( + "=", + _field(astx.Identifier("t"), "first"), + astx.LiteralInt32(1), + ), + astx.BinaryOp( + "=", + _field(astx.Identifier("t"), "middle"), + astx.LiteralInt32(2), + ), + astx.BinaryOp( + "=", + _field(astx.Identifier("t"), "last"), + astx.LiteralInt32(3), + ), + astx.FunctionReturn(_field(astx.Identifier("t"), "last")), + ) + module = make_module("main", triple, main_fn) - main_block = astx.Block() - main_block.append(struct_def) - main_block.append(astx.FunctionReturn(astx.LiteralInt32(0))) + ir_text = builder.translate(module) - main_fn = astx.FunctionDef( - prototype=main_proto, - body=main_block, + gep_first = ( + 'getelementptr inbounds %"main__Triple", ' + '%"main__Triple"* %"t", i32 0, i32 0' + ) + gep_middle = ( + 'getelementptr inbounds %"main__Triple", ' + '%"main__Triple"* %"t", i32 0, i32 1' + ) + gep_last = ( + 'getelementptr inbounds %"main__Triple", ' + '%"main__Triple"* %"t", i32 0, i32 2' ) + assert gep_first in ir_text + assert gep_middle in ir_text + assert gep_last in ir_text + assert_ir_parses(ir_text) - module.block.append(main_fn) + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_nested_field_access_reads_and_writes( + builder_class: type[Builder], +) -> None: + """ + title: Nested field access lowers through semantic field indices. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + inner = astx.StructDefStmt( + name="Payload", + attributes=[ + astx.VariableDeclaration(name="value", type_=astx.Int32()), + ], + ) + outer = astx.StructDefStmt( + name="Container", + attributes=[ + astx.VariableDeclaration( + name="inner", + type_=_struct_type("Payload"), + ), + astx.VariableDeclaration(name="enabled", type_=astx.Boolean()), + ], + ) + nested_value = _field(_field(astx.Identifier("o"), "inner"), "value") + main_fn = _main_int32( + _mutable_var("o", _struct_type("Container")), + astx.BinaryOp("=", nested_value, astx.LiteralInt32(9)), + astx.FunctionReturn( + _field(_field(astx.Identifier("o"), "inner"), "value") + ), + ) + module = make_module("main", inner, outer, main_fn) ir_text = builder.translate(module) - assert ( - f'%"{mangle_struct_name("main", "Value")}" = type {{i32}}' in ir_text + + assert '%"main__Container" = type {%"main__Payload", i1}' in ir_text + gep_inner = ( + 'getelementptr inbounds %"main__Container", ' + '%"main__Container"* %"o", i32 0, i32 0' ) + assert gep_inner in ir_text + assert_ir_parses(ir_text) @pytest.mark.parametrize("builder_class", [LLVMBuilder]) -def test_struct_definition_duplicate_name( +def test_struct_parameter_by_value_is_supported( builder_class: type[Builder], ) -> None: """ - title: Duplicate struct name raises error - summary: >- - Ensure defining a struct with the same name twice raises ValueError. + title: Struct parameters lower by value in IRx-defined functions. parameters: builder_class: type: type[Builder] """ + builder = builder_class() + read_point = astx.FunctionDef( + prototype=astx.FunctionPrototype( + name="read_x", + args=astx.Arguments(astx.Argument("p", _struct_type("Point"))), + return_type=astx.Int32(), + ), + body=astx.Block(), + ) + read_point.body.append( + astx.FunctionReturn(_field(astx.Identifier("p"), "x")) + ) + main_fn = _main_int32( + _mutable_var("p", _struct_type("Point")), + astx.BinaryOp( + "=", + _field(astx.Identifier("p"), "x"), + astx.LiteralInt32(4), + ), + astx.FunctionReturn( + astx.FunctionCall("read_x", [astx.Identifier("p")]) + ), + ) + module = make_module("main", _point_struct(), read_point, main_fn) + ir_text = builder.translate(module) + + assert 'define i32 @"main__read_x"(%"main__Point" %"p")' in ir_text + assert 'call i32 @"main__read_x"(%"main__Point" ' in ir_text + assert_ir_parses(ir_text) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_struct_return_by_value_and_assignment_from_call( + builder_class: type[Builder], +) -> None: + """ + title: Struct-returning functions round-trip by value through locals. + parameters: + builder_class: + type: type[Builder] + """ builder = builder_class() - module = builder.module() + make_point_body = astx.Block() + make_point_body.append(_mutable_var("p", _struct_type("Point"))) + make_point_body.append( + astx.BinaryOp( + "=", + _field(astx.Identifier("p"), "x"), + astx.LiteralInt32(11), + ) + ) + make_point_body.append(astx.FunctionReturn(astx.Identifier("p"))) + make_point = astx.FunctionDef( + prototype=astx.FunctionPrototype( + name="make_point", + args=astx.Arguments(), + return_type=_struct_type("Point"), + ), + body=make_point_body, + ) + main_fn = _main_int32( + _mutable_var( + "p", + _struct_type("Point"), + value=astx.FunctionCall("make_point", []), + ), + astx.FunctionReturn(_field(astx.Identifier("p"), "x")), + ) + module = make_module("main", _point_struct(), make_point, main_fn) + + ir_text = builder.translate(module) + + assert 'define %"main__Point" @"main__make_point"()' in ir_text + assert 'call %"main__Point" @"main__make_point"()' in ir_text + assert 'store %"main__Point" %"calltmp", %"main__Point"* %"p"' in ir_text + assert_ir_parses(ir_text) - struct_a = astx.StructDefStmt( - name="Duplicate", + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_nested_struct_return_by_value_is_supported( + builder_class: type[Builder], +) -> None: + """ + title: Nested structs can be returned by value without hidden fields. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + inner = astx.StructDefStmt( + name="Payload", attributes=[ - astx.VariableDeclaration(name="x", type_=astx.Int32()), + astx.VariableDeclaration(name="value", type_=astx.Int32()), ], ) - - struct_b = astx.StructDefStmt( - name="Duplicate", + outer = astx.StructDefStmt( + name="Container", attributes=[ - astx.VariableDeclaration(name="y", type_=astx.Int32()), + astx.VariableDeclaration( + name="inner", + type_=_struct_type("Payload"), + ), + astx.VariableDeclaration(name="ready", type_=astx.Boolean()), ], ) + make_outer_body = astx.Block() + make_outer_body.append(_mutable_var("o", _struct_type("Container"))) + make_outer_body.append( + astx.BinaryOp( + "=", + _field(_field(astx.Identifier("o"), "inner"), "value"), + astx.LiteralInt32(21), + ) + ) + make_outer_body.append(astx.FunctionReturn(astx.Identifier("o"))) + make_outer = astx.FunctionDef( + prototype=astx.FunctionPrototype( + name="make_outer", + args=astx.Arguments(), + return_type=_struct_type("Container"), + ), + body=make_outer_body, + ) + main_fn = _main_int32( + _mutable_var( + "o", + _struct_type("Container"), + value=astx.FunctionCall("make_outer", []), + ), + astx.FunctionReturn( + _field(_field(astx.Identifier("o"), "inner"), "value") + ), + ) + module = make_module("main", inner, outer, make_outer, main_fn) + + ir_text = builder.translate(module) + + assert '%"main__Container" = type {%"main__Payload", i1}' in ir_text + assert 'define %"main__Container" @"main__make_outer"()' in ir_text + assert_ir_parses(ir_text) - main_proto = astx.FunctionPrototype( - name="main", - args=astx.Arguments(), - return_type=astx.Int32(), + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_duplicate_struct_name_is_rejected( + builder_class: type[Builder], +) -> None: + """ + title: Duplicate struct names are semantic errors. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + module = make_module( + "main", + astx.StructDefStmt( + name="Duplicate", + attributes=[ + astx.VariableDeclaration(name="x", type_=astx.Int32()) + ], + ), + astx.StructDefStmt( + name="Duplicate", + attributes=[ + astx.VariableDeclaration(name="y", type_=astx.Int32()) + ], + ), + _main_int32(), ) - main_block = astx.Block() - main_block.append(struct_a) - main_block.append(struct_b) - main_block.append(astx.FunctionReturn(astx.LiteralInt32(0))) + with pytest.raises( + SemanticError, + match=r"Struct 'Duplicate' already defined\.", + ): + builder.translate(module) + - main_fn = astx.FunctionDef( - prototype=main_proto, - body=main_block, +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_duplicate_field_name_is_rejected( + builder_class: type[Builder], +) -> None: + """ + title: Duplicate field names are semantic errors. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + module = make_module( + "main", + astx.StructDefStmt( + name="Broken", + attributes=[ + astx.VariableDeclaration(name="x", type_=astx.Int32()), + astx.VariableDeclaration(name="x", type_=astx.Int32()), + ], + ), + _main_int32(), ) - module.block.append(main_fn) + with pytest.raises( + SemanticError, + match=r"Struct field 'x' already defined\.", + ): + builder.translate(module) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_unknown_field_type_is_rejected(builder_class: type[Builder]) -> None: + """ + title: Unknown field types are semantic errors. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + module = make_module( + "main", + astx.StructDefStmt( + name="Broken", + attributes=[ + astx.VariableDeclaration( + name="child", + type_=_struct_type("Missing"), + ), + ], + ), + _main_int32(), + ) + + with pytest.raises(SemanticError, match="Unknown field type 'Missing'"): + builder.translate(module) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_empty_structs_are_rejected(builder_class: type[Builder]) -> None: + """ + title: Empty structs are explicitly unsupported for now. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + module = make_module( + "main", + astx.StructDefStmt(name="Empty", attributes=[]), + _main_int32(), + ) + + with pytest.raises( + SemanticError, + match="Struct 'Empty' must declare at least one field", + ): + builder.translate(module) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_invalid_field_access_is_rejected( + builder_class: type[Builder], +) -> None: + """ + title: Unknown fields are rejected before lowering. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + main_fn = _main_int32( + _mutable_var("p", _struct_type("Point")), + astx.FunctionReturn(_field(astx.Identifier("p"), "missing")), + ) + module = make_module("main", _point_struct(), main_fn) + + with pytest.raises( + SemanticError, + match="struct 'Point' has no field 'missing'", + ): + builder.translate(module) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_direct_recursive_struct_is_rejected( + builder_class: type[Builder], +) -> None: + """ + title: Direct by-value recursive structs are forbidden. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + module = make_module( + "main", + astx.StructDefStmt( + name="Node", + attributes=[ + astx.VariableDeclaration( + name="next", + type_=_struct_type("Node"), + ), + ], + ), + _main_int32(), + ) + + with pytest.raises( + SemanticError, + match="direct by-value recursive struct 'Node' is forbidden", + ): + builder.translate(module) + + +@pytest.mark.parametrize("builder_class", [LLVMBuilder]) +def test_mutual_recursive_structs_are_rejected( + builder_class: type[Builder], +) -> None: + """ + title: Mutual by-value recursive structs are forbidden. + parameters: + builder_class: + type: type[Builder] + """ + builder = builder_class() + module = make_module( + "main", + astx.StructDefStmt( + name="Left", + attributes=[ + astx.VariableDeclaration( + name="right", + type_=_struct_type("Right"), + ), + ], + ), + astx.StructDefStmt( + name="Right", + attributes=[ + astx.VariableDeclaration( + name="left", + type_=_struct_type("Left"), + ), + ], + ), + _main_int32(), + ) - with pytest.raises(SemanticError, match="already defined"): + with pytest.raises( + SemanticError, + match=( + "mutual by-value recursive structs are forbidden: " + "Left -> Right -> Left" + ), + ): builder.translate(module)