From 6aba9798d5254d437e595d15e9a50d8967801882 Mon Sep 17 00:00:00 2001 From: Georgii Plotnikov Date: Mon, 15 Dec 2025 13:11:11 +0900 Subject: [PATCH 1/4] Add `HIR` --- Cargo.toml | 3 +- core/ast/src/{node.rs => _node.rs} | 0 core/ast/src/builder.rs | 31 +-- core/ast/src/lib.rs | 5 +- core/ast/src/node_kind.rs | 51 ---- core/ast/src/nodes.rs | 71 ++---- core/ast/src/nodes_impl.rs | 101 +------- core/ast/src/t_ast.rs | 61 ----- core/hir/Cargo.toml | 11 + core/hir/src/arena.rs | 2 + core/hir/src/hir.rs | 18 ++ core/hir/src/lib.rs | 8 + core/hir/src/nodes.rs | 325 ++++++++++++++++++++++++ core/hir/src/nodes_impl.rs | 62 +++++ core/{ast => hir}/src/type_infer.rs | 88 +++---- core/{ast => hir}/src/type_inference.rs | 3 +- core/{ast => hir}/src/type_info.rs | 16 +- core/wasm-codegen/Cargo.toml | 2 +- core/wasm-codegen/src/compiler.rs | 2 +- 19 files changed, 517 insertions(+), 343 deletions(-) rename core/ast/src/{node.rs => _node.rs} (100%) delete mode 100644 core/ast/src/node_kind.rs delete mode 100644 core/ast/src/t_ast.rs create mode 100644 core/hir/Cargo.toml create mode 100644 core/hir/src/arena.rs create mode 100644 core/hir/src/hir.rs create mode 100644 core/hir/src/lib.rs create mode 100644 core/hir/src/nodes.rs create mode 100644 core/hir/src/nodes_impl.rs rename core/{ast => hir}/src/type_infer.rs (94%) rename core/{ast => hir}/src/type_inference.rs (99%) rename core/{ast => hir}/src/type_info.rs (94%) diff --git a/Cargo.toml b/Cargo.toml index 13330c49..26079e8a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ repository = "https://github.com/Inferara/inference" # Core Inference crates inference = { path = "./core/inference", version = "0.0.1" } inference-ast = { path = "./core/ast", version = "0.0.1" } +inference-hir = { path = "./core/hir", version = "0.0.1" } inference-cli = { path = "./core/cli", version = "0.0.1" } inference-wasm-to-v-translator = { path = "./core/wasm-to-v", version = "0.0.1" } inference-wasm-codegen = { path = "./core/wasm-codegen", version = "0.0.1" } @@ -34,7 +35,7 @@ wat-fmt = { path = "./tools/wat-fmt", version = "0.0.9" } inf-wast = { path = "./tools/inf-wast", version = "0.0.9" } inf-wasmparser = { path = "./tools/inf-wasmparser", version = "0.0.9" } -tree-sitter = "0.26.2" +tree-sitter = "0.26.3" tree-sitter-inference = "0.0.37" anyhow = "1.0.100" serde = { version = "1.0.228", features = ["derive", "rc"] } diff --git a/core/ast/src/node.rs b/core/ast/src/_node.rs similarity index 100% rename from core/ast/src/node.rs rename to core/ast/src/_node.rs diff --git a/core/ast/src/builder.rs b/core/ast/src/builder.rs index ce83aeb4..f05e6069 100644 --- a/core/ast/src/builder.rs +++ b/core/ast/src/builder.rs @@ -4,8 +4,6 @@ use crate::nodes::{ ArgumentType, Directive, IgnoreArgument, Misc, SelfReference, StructExpression, TypeMemberAccessExpression, }; -use crate::type_infer::TypeChecker; -use crate::type_info::TypeInfo; use crate::{ arena::Arena, nodes::{ @@ -20,7 +18,6 @@ use crate::{ TypeQualifiedName, UnaryOperatorKind, UnitLiteral, UseDirective, UzumakiExpression, VariableDefinitionStatement, }, - t_ast::TypedAst, }; use tree_sitter::Node; @@ -40,7 +37,6 @@ pub type CompletedBuilder<'a> = Builder<'a, CompleteState>; pub struct Builder<'a, S> { arena: Arena, source_code: Vec<(Node<'a>, &'a [u8])>, - t_ast: Option, _state: PhantomData, } @@ -56,7 +52,6 @@ impl<'a> Builder<'a, InitState> { Self { arena: Arena::default(), source_code: Vec::new(), - t_ast: None, _state: PhantomData, } } @@ -109,20 +104,9 @@ impl<'a> Builder<'a, InitState> { } res.push(ast); } - let mut type_checker = TypeChecker::new(); - let _ = type_checker.infer_types(&mut res); - - // let mut type_checker = TypeChecker::new(); - let t_ast = TypedAst::new(res, self.arena.clone()); - t_ast.infer_expression_types(); - // run type inference over all expressions - // type_checker - // .infer_types(&t_ast.source_files) - // .map_err(|e| anyhow::Error::msg(format!("Type error: {e:?}")))?; Ok(Builder { arena: Arena::default(), source_code: Vec::new(), - t_ast: Some(t_ast), _state: PhantomData, }) } @@ -1260,9 +1244,6 @@ impl<'a> Builder<'a, InitState> { node.utf8_text(code).unwrap().to_string() }; let node = Rc::new(SimpleType::new(id, location, name)); - node.type_info - .borrow_mut() - .replace(TypeInfo::new(&Type::Simple(node.clone()))); self.arena.add_node( AstNode::Expression(Expression::Type(Type::Simple(node.clone()))), parent_id, @@ -1418,14 +1399,4 @@ impl<'a> Builder<'a, InitState> { } } -impl Builder<'_, CompleteState> { - /// Returns typed AST - /// - /// # Panics - /// - /// This function will panic if resulted `TypedAst` is `None` which means an error occured during the parsing process. - #[must_use] - pub fn t_ast(self) -> TypedAst { - self.t_ast.unwrap() - } -} +impl Builder<'_, CompleteState> {} diff --git a/core/ast/src/lib.rs b/core/ast/src/lib.rs index fdaed474..fa3abec1 100644 --- a/core/ast/src/lib.rs +++ b/core/ast/src/lib.rs @@ -1,9 +1,6 @@ #![warn(clippy::pedantic)] -pub(crate) mod arena; +pub mod arena; pub mod builder; pub(crate) mod enums_impl; pub mod nodes; pub(crate) mod nodes_impl; -pub mod t_ast; -pub mod type_infer; -pub mod type_info; diff --git a/core/ast/src/node_kind.rs b/core/ast/src/node_kind.rs deleted file mode 100644 index fd7ba45e..00000000 --- a/core/ast/src/node_kind.rs +++ /dev/null @@ -1,51 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::{ - node::Location, - types::{Definition, Expression, Literal, Statement, Type}, -}; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub enum NodeKind { - Definition(Definition), - Statement(Statement), - Expression(Expression), - Literal(Literal), - Type(Type), -} - -impl NodeKind { - #[must_use] - pub fn id(&self) -> u32 { - match self { - NodeKind::Definition(d) => d.id(), - NodeKind::Statement(s) => s.id(), - NodeKind::Expression(e) => e.id(), - NodeKind::Literal(l) => l.id(), - NodeKind::Type(t) => t.id(), - } - } - - #[must_use] - #[allow(clippy::cast_possible_truncation)] - pub fn location(&self) -> Location { - match self { - NodeKind::Definition(d) => d.location().clone(), - NodeKind::Statement(s) => s.location(), - NodeKind::Expression(e) => e.location(), - NodeKind::Literal(l) => l.location(), - NodeKind::Type(t) => t.location(), - } - } - - #[must_use] - pub fn children(&self) -> Vec { - match self { - NodeKind::Definition(definition) => definition.children(), - NodeKind::Statement(statement) => statement.children(), - NodeKind::Expression(expression) => expression.children(), - NodeKind::Literal(literal) => literal.children(), - NodeKind::Type(ty) => ty.children(), - } - } -} diff --git a/core/ast/src/nodes.rs b/core/ast/src/nodes.rs index bd545afe..1422db17 100644 --- a/core/ast/src/nodes.rs +++ b/core/ast/src/nodes.rs @@ -5,8 +5,6 @@ use std::{ rc::Rc, }; -use crate::type_info::TypeInfo; - #[derive(Clone, PartialEq, Eq, Debug, Default)] pub struct Location { pub offset_start: u32, @@ -46,7 +44,13 @@ impl Display for Location { write!( f, "Location {{ offset_start: {}, offset_end: {}, start_line: {}, start_column: {}, end_line: {}, end_column: {}, source: {} }}", - self.offset_start, self.offset_end, self.start_line, self.start_column, self.end_line, self.end_column, self.source + self.offset_start, + self.offset_end, + self.start_line, + self.start_column, + self.end_line, + self.end_column, + self.source ) } } @@ -338,8 +342,7 @@ ast_nodes! { } pub struct Identifier { - pub name: String, - pub type_info: RefCell> //TODO revisit + pub name: String } pub struct ConstantDefinition { @@ -410,58 +413,50 @@ ast_nodes! { pub name: Rc, pub ty: Type, pub value: Option>, - pub is_uzumaki: bool, + pub is_uzumaki: bool } pub struct TypeDefinitionStatement { pub name: Rc, - pub ty: Type, + pub ty: Type } pub struct AssignStatement { pub left: RefCell, - pub right: RefCell, + pub right: RefCell } pub struct ArrayIndexAccessExpression { pub array: RefCell, - pub index: RefCell, - pub type_info: RefCell> + pub index: RefCell } pub struct MemberAccessExpression { pub expression: RefCell, - pub name: Rc, - pub type_info: RefCell> + pub name: Rc } pub struct TypeMemberAccessExpression { pub expression: RefCell, - pub name: Rc, - pub type_info: RefCell> + pub name: Rc } pub struct FunctionCallExpression { pub function: Expression, pub type_parameters: Option>>, - pub arguments: Option>, RefCell)>>, - pub type_info: RefCell> + pub arguments: Option>, RefCell)>> } pub struct StructExpression { pub name: Rc, - pub fields: Option, RefCell)>>, - pub type_info: RefCell> + pub fields: Option, RefCell)>> } - pub struct UzumakiExpression { - pub type_info: RefCell> - } + pub struct UzumakiExpression {} pub struct PrefixUnaryExpression { pub expression: RefCell, - pub operator: UnaryOperatorKind, - pub type_info: RefCell> + pub operator: UnaryOperatorKind } pub struct AssertStatement { @@ -469,20 +464,17 @@ ast_nodes! { } pub struct ParenthesizedExpression { - pub expression: RefCell, - pub type_info: RefCell> + pub expression: RefCell } pub struct BinaryExpression { pub left: RefCell, pub operator: OperatorKind, - pub right: RefCell, - pub type_info: RefCell> + pub right: RefCell } pub struct ArrayLiteral { - pub elements: Option>>, - pub type_info: RefCell> + pub elements: Option>> } pub struct BoolLiteral { @@ -494,46 +486,39 @@ ast_nodes! { } pub struct NumberLiteral { - pub value: String, - pub type_info: RefCell> + pub value: String } pub struct UnitLiteral { } pub struct SimpleType { - pub name: String, - pub type_info: RefCell> + pub name: String } pub struct GenericType { pub base: Rc, - pub parameters: Vec>, - pub type_info: RefCell> + pub parameters: Vec> } pub struct FunctionType { pub parameters: Option>, - pub returns: Option, - pub type_info: RefCell> + pub returns: Option } pub struct QualifiedName { pub qualifier: Rc, - pub name: Rc, - pub type_info: RefCell> + pub name: Rc } pub struct TypeQualifiedName { pub alias: Rc, - pub name: Rc, - pub type_info: RefCell> + pub name: Rc } pub struct TypeArray { pub element_type: Type, - pub size: Option, - pub type_info: RefCell> + pub size: Option } } diff --git a/core/ast/src/nodes_impl.rs b/core/ast/src/nodes_impl.rs index 6655fa37..2247f8f9 100644 --- a/core/ast/src/nodes_impl.rs +++ b/core/ast/src/nodes_impl.rs @@ -1,10 +1,7 @@ use std::{cell::RefCell, rc::Rc}; -use crate::{ - nodes::{ - ArgumentType, IgnoreArgument, SelfReference, StructExpression, TypeMemberAccessExpression, - }, - type_info::{NumberTypeKindNumberType, TypeInfo, TypeInfoKind}, +use crate::nodes::{ + ArgumentType, IgnoreArgument, SelfReference, StructExpression, TypeMemberAccessExpression, }; use super::nodes::{ @@ -180,51 +177,12 @@ impl Statement { } impl Expression { - #[must_use] - pub fn type_info(&self) -> Option { - match self { - Expression::ArrayIndexAccess(e) => e.type_info.borrow().clone(), - Expression::MemberAccess(e) => e.type_info.borrow().clone(), - Expression::TypeMemberAccess(e) => e.type_info.borrow().clone(), - Expression::FunctionCall(e) => e.type_info.borrow().clone(), - Expression::Struct(e) => e.type_info.borrow().clone(), - Expression::PrefixUnary(e) => e.type_info.borrow().clone(), - Expression::Parenthesized(e) => e.type_info.borrow().clone(), - Expression::Binary(e) => e.type_info.borrow().clone(), - Expression::Literal(l) => l.type_info(), - Expression::Identifier(e) => e.type_info.borrow().clone(), - Expression::Type(e) => Some(TypeInfo::new(e)), - Expression::Uzumaki(e) => e.type_info.borrow().clone(), - } - } #[must_use] pub fn is_non_det(&self) -> bool { matches!(self, Expression::Uzumaki(_)) } } -impl Literal { - #[must_use] - pub fn type_info(&self) -> Option { - match self { - Literal::Bool(_) => Some(TypeInfo { - kind: TypeInfoKind::Bool, - type_params: vec![], - }), - Literal::Number(literal) => literal.type_info.borrow().clone(), - Literal::String(_) => Some(TypeInfo { - kind: TypeInfoKind::String, - type_params: vec![], - }), - Literal::Unit(_) => Some(TypeInfo { - kind: TypeInfoKind::Unit, - type_params: vec![], - }), - Literal::Array(literal) => literal.type_info.borrow().clone(), - } - } -} - impl UseDirective { #[must_use] pub fn new( @@ -327,12 +285,7 @@ impl EnumDefinition { impl Identifier { #[must_use] pub fn new(id: u32, name: String, location: Location) -> Self { - Identifier { - id, - location, - name, - type_info: RefCell::new(None), - } + Identifier { id, location, name } } #[must_use] @@ -629,7 +582,6 @@ impl ArrayIndexAccessExpression { location, array: RefCell::new(array), index: RefCell::new(index), - type_info: RefCell::new(None), } } } @@ -642,7 +594,6 @@ impl MemberAccessExpression { location, expression: RefCell::new(expression), name, - type_info: RefCell::new(None), } } } @@ -660,7 +611,6 @@ impl TypeMemberAccessExpression { location, expression: RefCell::new(type_expression), name, - type_info: RefCell::new(None), } } } @@ -685,7 +635,6 @@ impl FunctionCallExpression { function, type_parameters, arguments, - type_info: RefCell::new(None), } } @@ -719,7 +668,6 @@ impl StructExpression { location, name, fields, - type_info: RefCell::new(None), } } @@ -742,7 +690,6 @@ impl PrefixUnaryExpression { location, expression: RefCell::new(expression), operator, - type_info: RefCell::new(None), } } } @@ -750,31 +697,7 @@ impl PrefixUnaryExpression { impl UzumakiExpression { #[must_use] pub fn new(id: u32, location: Location) -> Self { - UzumakiExpression { - id, - location, - type_info: RefCell::new(None), - } - } - #[must_use] - pub fn is_i32(&self) -> bool { - if let Some(type_info) = self.type_info.borrow().as_ref() { - return matches!( - type_info.kind, - TypeInfoKind::Number(NumberTypeKindNumberType::I32) - ); - } - false - } - #[must_use] - pub fn is_i64(&self) -> bool { - if let Some(type_info) = self.type_info.borrow().as_ref() { - return matches!( - type_info.kind, - TypeInfoKind::Number(NumberTypeKindNumberType::I64) - ); - } - false + UzumakiExpression { id, location } } } @@ -796,7 +719,6 @@ impl ParenthesizedExpression { id, location, expression: RefCell::new(expression), - type_info: RefCell::new(None), } } } @@ -816,7 +738,6 @@ impl BinaryExpression { left: RefCell::new(left), operator, right: RefCell::new(right), - type_info: RefCell::new(None), } } } @@ -839,7 +760,6 @@ impl ArrayLiteral { id, location, elements: elements.map(|vec| vec.into_iter().map(RefCell::new).collect()), - type_info: RefCell::new(None), } } } @@ -862,7 +782,6 @@ impl NumberLiteral { id, location, value, - type_info: RefCell::new(None), } } } @@ -877,12 +796,7 @@ impl UnitLiteral { impl SimpleType { #[must_use] pub fn new(id: u32, location: Location, name: String) -> Self { - SimpleType { - id, - location, - name, - type_info: RefCell::new(None), - } + SimpleType { id, location, name } } } @@ -899,7 +813,6 @@ impl GenericType { location, base, parameters, - type_info: RefCell::new(None), } } } @@ -917,7 +830,6 @@ impl FunctionType { location, parameters, returns, - type_info: RefCell::new(None), } } } @@ -935,7 +847,6 @@ impl QualifiedName { location, qualifier, name, - type_info: RefCell::new(None), } } @@ -958,7 +869,6 @@ impl TypeQualifiedName { location, alias, name, - type_info: RefCell::new(None), } } @@ -981,7 +891,6 @@ impl TypeArray { location, element_type, size, - type_info: RefCell::new(None), } } } diff --git a/core/ast/src/t_ast.rs b/core/ast/src/t_ast.rs deleted file mode 100644 index 0caab902..00000000 --- a/core/ast/src/t_ast.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::{ - arena::Arena, - nodes::{AstNode, Definition, Expression, SourceFile, Statement}, - type_info::TypeInfo, -}; - -#[derive(Clone, Default)] -pub struct TypedAst { - pub source_files: Vec, - arena: Arena, -} - -impl TypedAst { - #[must_use] - pub fn new(source_files: Vec, arena: Arena) -> Self { - Self { - source_files, - arena, - } - } - - pub fn filter_nodes bool>(&self, fn_predicate: T) -> Vec { - self.arena - .nodes - .values() - .filter(|node| fn_predicate(node)) - .cloned() - .collect() - } - - pub fn infer_expression_types(&self) { - //FIXME: very hacky way to infer Uzumaki expression types in return statements - for function_def_node in - self.filter_nodes(|node| matches!(node, AstNode::Definition(Definition::Function(_)))) - { - let AstNode::Definition(Definition::Function(function_def)) = function_def_node else { - unreachable!() - }; - if function_def.is_void() { - continue; - } - if let Some(Statement::Return(last_stmt)) = function_def.body.statements().last() { - if !matches!(*last_stmt.expression.borrow(), Expression::Uzumaki(_)) { - continue; - } - - match &*last_stmt.expression.borrow() { - Expression::Uzumaki(expr) => { - if expr.type_info.borrow().is_some() { - continue; - } - if let Some(return_type) = &function_def.returns { - expr.type_info.replace(Some(TypeInfo::new(return_type))); - } - } - _ => unreachable!(), - } - } - } - } -} diff --git a/core/hir/Cargo.toml b/core/hir/Cargo.toml new file mode 100644 index 00000000..2e146811 --- /dev/null +++ b/core/hir/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "inference-hir" +version = { workspace = true } +edition = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +inference-ast.workspace = true +anyhow.workspace = true \ No newline at end of file diff --git a/core/hir/src/arena.rs b/core/hir/src/arena.rs new file mode 100644 index 00000000..3072e9fe --- /dev/null +++ b/core/hir/src/arena.rs @@ -0,0 +1,2 @@ +#[derive(Default, Clone)] +pub struct Arena {} diff --git a/core/hir/src/hir.rs b/core/hir/src/hir.rs new file mode 100644 index 00000000..8fc4316d --- /dev/null +++ b/core/hir/src/hir.rs @@ -0,0 +1,18 @@ +use crate::{arena::Arena, type_infer::TypeChecker}; +use inference_ast::arena::Arena as AstArena; + +#[derive(Clone, Default)] +pub struct Hir { + pub arena: Arena, + pub type_checker: TypeChecker, +} + +impl Hir { + #[must_use] + pub fn new(arena: AstArena) -> Self { + Self { + arena: Arena::default(), + type_checker: TypeChecker::default(), + } + } +} diff --git a/core/hir/src/lib.rs b/core/hir/src/lib.rs new file mode 100644 index 00000000..c6d16d5f --- /dev/null +++ b/core/hir/src/lib.rs @@ -0,0 +1,8 @@ +#![warn(clippy::pedantic)] +pub mod hir; +pub mod nodes; +mod nodes_impl; +pub mod type_infer; +// mod type_inference; +pub mod arena; +pub mod type_info; diff --git a/core/hir/src/nodes.rs b/core/hir/src/nodes.rs new file mode 100644 index 00000000..e7db761a --- /dev/null +++ b/core/hir/src/nodes.rs @@ -0,0 +1,325 @@ +use std::rc::Rc; + +use crate::type_info::TypeInfo; + +pub enum Directive { + Use(Rc), +} + +pub enum Definition { + Spec(Rc), + Struct(Rc), + Enum(Rc), + Constant(Rc), + Function(Rc), + ExternalFunction(Rc), + Type(Rc), +} + +pub enum BlockType { + Block(Rc), + Assume(Rc), + Forall(Rc), + Exists(Rc), + Unique(Rc), +} + +pub enum Statement { + Block(BlockType), + Expression(Expression), + Assign(Rc), + Return(Rc), + Loop(Rc), + Break(Rc), + If(Rc), + VariableDefinition(Rc), + TypeDefinition(Rc), + Assert(Rc), + ConstantDefinition(Rc), +} + +pub enum Expression { + ArrayIndexAccess(Rc), + Binary(Rc), + MemberAccess(Rc), + TypeMemberAccess(Rc), + FunctionCall(Rc), + Struct(Rc), + PrefixUnary(Rc), + Parenthesized(Rc), + Literal(Literal), + TypeInfo(TypeInfo), //TODO: need it + Uzumaki(Rc), +} + +pub enum Literal { + Array(Rc), + Bool(Rc), + String(Rc), + Number(Rc), + Unit(Rc), +} + +pub enum ArgumentType { + SelfReference(Rc), + IgnoreArgument(Rc), + Argument(Rc), + TypeInfo(TypeInfo), +} + +pub enum Misc { + StructField(Rc), +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum UnaryOperatorKind { + Neg, +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum OperatorKind { + Pow, + Add, + Sub, + Mul, + Div, + Mod, + And, + Or, + Eq, + Ne, + Lt, + Le, + Gt, + Ge, + BitAnd, + BitOr, + BitXor, + BitNot, + Shl, + Shr, +} + +pub struct SourceFile { + pub directives: Vec, + pub definitions: Vec, +} + +pub struct UseDirective { + pub imported_types: Option>, + pub segments: Option>, + pub from: Option, +} + +pub struct SpecDefinition { + pub name: String, + pub definitions: Vec, +} + +pub struct StructDefinition { + pub name: String, + pub fields: Vec>, + pub methods: Vec>, +} + +pub struct StructField { + pub name: String, + pub type_info: TypeInfo, +} + +pub struct EnumDefinition { + pub name: String, + pub variants: Vec, +} + +pub struct ConstantDefinition { + pub name: String, + pub type_info: TypeInfo, + pub value: Literal, +} + +pub struct FunctionDefinition { + pub name: String, + pub type_parameters: Option>, + pub arguments: Option>, + pub returns: Option, + pub body: BlockType, +} + +pub struct ExternalFunctionDefinition { + pub name: String, + pub arguments: Option>, + pub returns: Option, +} + +pub struct TypeDefinition { + pub name: String, + pub type_info: TypeInfo, +} + +pub struct Argument { + pub name: String, + pub is_mut: bool, + pub type_info: TypeInfo, +} + +pub struct SelfReference { + pub is_mut: bool, +} + +pub struct IgnoreArgument { + pub type_info: TypeInfo, +} + +pub struct Block { + pub statements: Vec, +} + +pub struct ExpressionStatement { + pub expression: Expression, +} + +pub struct ReturnStatement { + pub expression: Expression, +} + +pub struct LoopStatement { + pub condition: Option, + pub body: BlockType, +} + +pub struct BreakStatement {} + +pub struct IfStatement { + pub condition: Expression, + pub if_arm: BlockType, + pub else_arm: Option, +} + +pub struct VariableDefinitionStatement { + pub name: String, + pub type_info: TypeInfo, + pub value: Option, //TODO: revisit + pub is_uzumaki: bool, +} + +pub struct TypeDefinitionStatement { + pub name: String, + pub type_info: TypeInfo, +} + +pub struct AssignStatement { + pub left: Expression, + pub right: Expression, +} + +pub struct ArrayIndexAccessExpression { + pub array: Expression, + pub index: Expression, + pub type_info: TypeInfo, +} + +pub struct MemberAccessExpression { + pub expression: Expression, + pub name: String, + pub type_info: TypeInfo, +} + +pub struct TypeMemberAccessExpression { + pub expression: Expression, + pub name: String, + pub type_info: TypeInfo, +} + +pub struct FunctionCallExpression { + pub name: String, + pub function: Expression, + pub type_parameters: Option>, + pub arguments: Option, Expression)>>, + pub type_info: TypeInfo, +} + +pub struct StructExpression { + pub name: String, + pub fields: Option>, + pub type_info: TypeInfo, +} + +pub struct UzumakiExpression { + pub type_info: TypeInfo, +} + +pub struct PrefixUnaryExpression { + pub expression: Expression, + pub operator: UnaryOperatorKind, + pub type_info: TypeInfo, +} + +pub struct AssertStatement { + pub expression: Expression, +} + +pub struct ParenthesizedExpression { + pub expression: Expression, + pub type_info: TypeInfo, +} + +pub struct BinaryExpression { + pub left: Expression, + pub operator: OperatorKind, + pub right: Expression, + pub type_info: TypeInfo, +} + +pub struct ArrayLiteral { + pub elements: Option>, + pub type_info: TypeInfo, +} + +pub struct BoolLiteral { + pub value: bool, +} + +pub struct StringLiteral { + pub value: String, +} + +pub struct NumberLiteral { + pub value: String, + pub type_info: TypeInfo, +} + +pub struct UnitLiteral {} + +pub struct SimpleType { + pub name: String, + pub type_info: TypeInfo, +} + +pub struct GenericType { + pub base: String, + pub parameters: Vec, + pub type_info: TypeInfo, +} + +pub struct FunctionType { + pub parameters: Option>, + pub returns: Option, +} + +pub struct QualifiedName { + pub qualifier: String, + pub name: String, + pub type_info: TypeInfo, +} + +pub struct TypeQualifiedName { + pub alias: String, + pub name: String, + pub type_info: TypeInfo, +} + +pub struct TypeArray { + pub element_type: TypeInfo, + pub size: Option, +} diff --git a/core/hir/src/nodes_impl.rs b/core/hir/src/nodes_impl.rs new file mode 100644 index 00000000..f72169c3 --- /dev/null +++ b/core/hir/src/nodes_impl.rs @@ -0,0 +1,62 @@ +use crate::{ + nodes::{Expression, Literal, UzumakiExpression}, + type_info::{NumberTypeKindNumberType, TypeInfo, TypeInfoKind}, +}; + +impl Expression { + #[must_use] + pub fn type_info(&self) -> TypeInfo { + match self { + Expression::ArrayIndexAccess(e) => e.type_info.clone(), + Expression::MemberAccess(e) => e.type_info.clone(), + Expression::TypeMemberAccess(e) => e.type_info.clone(), + Expression::FunctionCall(e) => e.type_info.clone(), + Expression::Struct(e) => e.type_info.clone(), + Expression::PrefixUnary(e) => e.type_info.clone(), + Expression::Parenthesized(e) => e.type_info.clone(), + Expression::Binary(e) => e.type_info.clone(), + Expression::Literal(l) => l.type_info(), + Expression::TypeInfo(e) => e.clone(), + Expression::Uzumaki(e) => e.type_info.clone(), + } + } +} + +impl Literal { + #[must_use] + pub fn type_info(&self) -> TypeInfo { + match self { + Literal::Bool(_) => TypeInfo { + kind: TypeInfoKind::Bool, + type_params: vec![], + }, + Literal::Number(literal) => literal.type_info.clone(), + Literal::String(_) => TypeInfo { + kind: TypeInfoKind::String, + type_params: vec![], + }, + Literal::Unit(_) => TypeInfo { + kind: TypeInfoKind::Unit, + type_params: vec![], + }, + Literal::Array(literal) => literal.type_info.clone(), + } + } +} + +impl UzumakiExpression { + #[must_use] + pub fn is_i32(&self) -> bool { + return matches!( + self.type_info.kind, + TypeInfoKind::Number(NumberTypeKindNumberType::I32) + ); + } + #[must_use] + pub fn is_i64(&self) -> bool { + return matches!( + self.type_info.kind, + TypeInfoKind::Number(NumberTypeKindNumberType::I64) + ); + } +} diff --git a/core/ast/src/type_infer.rs b/core/hir/src/type_infer.rs similarity index 94% rename from core/ast/src/type_infer.rs rename to core/hir/src/type_infer.rs index 50e33338..acabbd50 100644 --- a/core/ast/src/type_infer.rs +++ b/core/hir/src/type_infer.rs @@ -1,11 +1,12 @@ use anyhow::bail; use crate::nodes::{ - ArgumentType, Definition, FunctionDefinition, Identifier, Literal, OperatorKind, Statement, + ArgumentType, Definition, FunctionDefinition, Literal, OperatorKind, Statement, UnaryOperatorKind, }; -use crate::nodes::{Expression, Location, SimpleType, SourceFile, Type}; +use crate::nodes::{Expression, SimpleType, SourceFile}; use crate::type_info::{NumberTypeKindNumberType, TypeInfo, TypeInfoKind}; +use inference_ast::nodes::Type; use std::collections::HashMap; use std::rc::Rc; @@ -16,6 +17,7 @@ struct FuncSignature { return_type: TypeInfo, } +#[derive(Clone, Default)] struct SymbolTable { types: HashMap, // map of type name -> type info functions: HashMap, // map of function name -> signature @@ -221,6 +223,7 @@ impl SymbolTable { } } +#[derive(Clone, Default)] pub(crate) struct TypeChecker { symbol_table: SymbolTable, errors: Vec, @@ -263,41 +266,41 @@ impl TypeChecker { match definition { Definition::Type(type_definition) => { self.symbol_table - .register_type(&type_definition.name(), Some(&type_definition.ty)) + .register_type(&type_definition.name, Some(&type_definition.ty)) .unwrap_or_else(|_| { self.errors.push(format!( "Error registering type `{}`", - type_definition.name() + type_definition.name )); }); } Definition::Struct(struct_definition) => { self.symbol_table - .register_struct(&struct_definition.name()) + .register_struct(&struct_definition.name) .unwrap_or_else(|_| { self.errors.push(format!( "Error registering type `{}`", - struct_definition.name() + struct_definition.name )); }); } Definition::Enum(enum_definition) => { self.symbol_table - .register_enum(&enum_definition.name()) + .register_enum(&enum_definition.name) .unwrap_or_else(|_| { self.errors.push(format!( "Error registering type `{}`", - enum_definition.name() + enum_definition.name )); }); } Definition::Spec(spec_definition) => { self.symbol_table - .register_spec(&spec_definition.name()) + .register_spec(&spec_definition.name) .unwrap_or_else(|_| { self.errors.push(format!( "Error registering type `{}`", - spec_definition.name() + spec_definition.name )); }); } @@ -316,7 +319,7 @@ impl TypeChecker { match definition { Definition::Constant(constant_definition) => { if let Err(err) = self.symbol_table.push_variable_to_scope( - constant_definition.name(), + constant_definition.name, TypeInfo::new(&constant_definition.ty), ) { self.errors.push(err.to_string()); @@ -358,14 +361,12 @@ impl TypeChecker { continue; } if let Err(err) = self.symbol_table.register_function( - &function_definition.name(), + &function_definition.name, function_definition .type_parameters .as_ref() - .unwrap_or(&vec![]) - .iter() - .map(|param| param.name()) - .collect::>(), + .cloned() + .unwrap_or_else(Vec::new), &function_definition .arguments .as_ref() @@ -385,7 +386,6 @@ impl TypeChecker { .as_ref() .unwrap_or(&Type::Simple(Rc::new(SimpleType::new( 0, - Location::default(), "Unit".into(), )))) .clone(), @@ -395,7 +395,7 @@ impl TypeChecker { } Definition::ExternalFunction(external_function_definition) => { if let Err(err) = self.symbol_table.register_function( - &external_function_definition.name(), + &external_function_definition.name, vec![], &external_function_definition .arguments @@ -414,11 +414,10 @@ impl TypeChecker { &external_function_definition .returns .as_ref() - .unwrap_or(&Type::Simple(Rc::new(SimpleType::new( - 0, - Location::default(), - "Unit".into(), - )))) + .unwrap_or(&Type::Simple(Rc::new(SimpleType { + name: "Unit".into(), + type_info: TypeInfo::default(), + }))) .clone(), ) { self.errors.push(err); @@ -432,7 +431,7 @@ impl TypeChecker { } } - fn validate_type(&mut self, ty: &Type, type_parameters: Option<&Vec>>) { + fn validate_type(&mut self, ty: &Type, type_parameters: Option<&Vec>) { match ty { Type::Array(type_array) => self.validate_type(&type_array.element_type, None), Type::Simple(simple_type) => { @@ -442,19 +441,15 @@ impl TypeChecker { } } Type::Generic(generic_type) => { - if self - .symbol_table - .lookup_type(&generic_type.base.name()) - .is_none() - { + if self.symbol_table.lookup_type(&generic_type.base).is_none() { self.errors - .push(format!("Unknown type `{}`", generic_type.base.name())); + .push(format!("Unknown type `{}`", generic_type.base)); } if let Some(type_params) = &type_parameters { if type_params.len() != generic_type.parameters.len() { self.errors.push(format!( "Type parameter count mismatch for `{}`: expected {}, found {}", - generic_type.base.name(), + generic_type.base, generic_type.parameters.len(), type_params.len() )); @@ -462,14 +457,13 @@ impl TypeChecker { let generic_param_names: Vec = generic_type .parameters .iter() - .map(|param| param.name()) + .map(|param| param.clone()) .collect(); for param in &generic_type.parameters { - if !generic_param_names.contains(¶m.name()) { + if !generic_param_names.contains(param) { self.errors.push(format!( "Type parameter `{}` not found in `{}`", - param.name(), - generic_type.base.name() + param, generic_type.base )); } } @@ -497,7 +491,7 @@ impl TypeChecker { ArgumentType::Argument(arg) => { if let Err(err) = self .symbol_table - .push_variable_to_scope(arg.name(), TypeInfo::new(&arg.ty)) + .push_variable_to_scope(arg.name, TypeInfo::new(&arg.ty)) { self.errors.push(err.to_string()); } @@ -617,7 +611,7 @@ impl TypeChecker { } } if let Err(err) = self.symbol_table.push_variable_to_scope( - variable_definition_statement.name(), + &variable_definition_statement.name, TypeInfo::new(&variable_definition_statement.ty), ) { self.errors.push(err.to_string()); @@ -647,7 +641,7 @@ impl TypeChecker { let constant_type = TypeInfo::new(&constant_definition.ty); if let Err(err) = self .symbol_table - .push_variable_to_scope(constant_definition.name(), constant_type) + .push_variable_to_scope(constant_definition.name, constant_type) { self.errors.push(err.to_string()); } @@ -723,13 +717,13 @@ impl TypeChecker { Expression::FunctionCall(function_call_expression) => { let signature = if let Some(s) = self .symbol_table - .lookup_function(&function_call_expression.name()) + .lookup_function(&function_call_expression.name) { s.clone() } else { self.errors.push(format!( "Call to undefined function `{}`", - function_call_expression.name() + function_call_expression.name )); if let Some(arguments) = &function_call_expression.arguments { for arg in arguments { @@ -743,7 +737,7 @@ impl TypeChecker { { self.errors.push(format!( "Function `{}` expects {} arguments, but {} provided", - function_call_expression.name(), + function_call_expression.name, signature.param_types.len(), arguments.len() )); @@ -781,7 +775,7 @@ impl TypeChecker { } let struct_type = self .symbol_table - .lookup_type(&struct_expression.name()) + .lookup_type(&struct_expression.name) .cloned(); if let Some(struct_type) = struct_type { *struct_expression.type_info.borrow_mut() = Some(struct_type.clone()); @@ -789,7 +783,7 @@ impl TypeChecker { } self.errors.push(format!( "Struct `{}` is not defined", - struct_expression.name() + struct_expression.name )); None } @@ -956,13 +950,13 @@ impl TypeChecker { } (Type::Simple(left), Type::Simple(right)) => left.name == right.name, (Type::Generic(left), Type::Generic(right)) => { - left.base.name() == right.base.name() && left.parameters == right.parameters + left.base == right.base && left.parameters == right.parameters } - (Type::Qualified(left), Type::Qualified(right)) => left.name() == right.name(), + (Type::Qualified(left), Type::Qualified(right)) => left.name == right.name, (Type::QualifiedName(left), Type::QualifiedName(right)) => { - left.qualifier() == right.qualifier() && left.name() == right.name() + left.qualifier == right.qualifier && left.name == right.name } - (Type::Custom(left), Type::Custom(right)) => left.name() == right.name(), + (Type::Custom(left), Type::Custom(right)) => left.name == right.name, (Type::Function(left), Type::Function(right)) => { let left_has_return_type = left.returns.is_some(); let right_has_return_type = right.returns.is_some(); diff --git a/core/ast/src/type_inference.rs b/core/hir/src/type_inference.rs similarity index 99% rename from core/ast/src/type_inference.rs rename to core/hir/src/type_inference.rs index 6178c5ce..05a6fa9f 100644 --- a/core/ast/src/type_inference.rs +++ b/core/hir/src/type_inference.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::types::{Expression, FunctionDefinition, OperatorKind, Statement, Type}; +use crate::nodes::{Expression, OperatorKind, Type}; /// The `type_inference` module provides functionality for performing multi-pass /// type checking and inference on a Rust-like abstract syntax tree (AST). @@ -47,7 +47,6 @@ use crate::types::{Expression, FunctionDefinition, OperatorKind, Statement, Type /// `TypeChecker.infer_types(&mut ast)` to annotate it or produce errors.) /// /// - /// Symbol table entry for a type definition. #[derive(Debug)] struct TypeInfo { diff --git a/core/ast/src/type_info.rs b/core/hir/src/type_info.rs similarity index 94% rename from core/ast/src/type_info.rs rename to core/hir/src/type_info.rs index 9532c318..7017bb8e 100644 --- a/core/ast/src/type_info.rs +++ b/core/hir/src/type_info.rs @@ -1,7 +1,7 @@ use core::fmt; use std::fmt::{Display, Formatter}; -use crate::nodes::Type; +use inference_ast::nodes::Type; #[derive(Debug, Eq, PartialEq, Clone)] pub enum NumberTypeKindNumberType { @@ -113,19 +113,23 @@ impl TypeInfo { type_params: vec![], }, Type::Generic(generic) => Self { - kind: TypeInfoKind::Generic(generic.base.name.clone()), - type_params: generic.parameters.iter().map(|p| p.name.clone()).collect(), + kind: TypeInfoKind::Generic(generic.base.name().clone()), + type_params: generic + .parameters + .iter() + .map(|p| p.name().clone()) + .collect(), }, Type::QualifiedName(qualified_name) => Self { kind: TypeInfoKind::QualifiedName(format!( "{}::{}", - qualified_name.qualifier(), - qualified_name.name() + qualified_name.qualifier.name(), + qualified_name.name.name() )), type_params: vec![], }, Type::Qualified(qualified) => Self { - kind: TypeInfoKind::Qualified(qualified.name.name.clone()), + kind: TypeInfoKind::Qualified(qualified.name.name().clone()), type_params: vec![], }, Type::Array(array) => Self { diff --git a/core/wasm-codegen/Cargo.toml b/core/wasm-codegen/Cargo.toml index 4e9ec557..f44cac8c 100644 --- a/core/wasm-codegen/Cargo.toml +++ b/core/wasm-codegen/Cargo.toml @@ -10,5 +10,5 @@ repository = { workspace = true } inkwell = { version = "0.7.1", features = ["llvm21-1"] } tempfile = "3.3.0" which = "8.0.0" -inference-ast.workspace = true +inference-hir.workspace = true anyhow.workspace = true diff --git a/core/wasm-codegen/src/compiler.rs b/core/wasm-codegen/src/compiler.rs index 8869340b..034c0ce1 100644 --- a/core/wasm-codegen/src/compiler.rs +++ b/core/wasm-codegen/src/compiler.rs @@ -1,7 +1,7 @@ //TODO: don't forget to remove #![allow(dead_code)] use crate::utils; -use inference_ast::{ +use inference_hir::{ nodes::{BlockType, Expression, FunctionDefinition, Literal, Statement, Type}, type_info::{NumberTypeKindNumberType, TypeInfoKind}, }; From af5f73f93fb33a7759254e35cd6e78a53bc84bea Mon Sep 17 00:00:00 2001 From: Georgii Plotnikov Date: Mon, 15 Dec 2025 15:57:38 +0900 Subject: [PATCH 2/4] WIP --- core/ast/src/lib.rs | 1 + core/ast/src/node_type.rs | 311 ++++ core/hir/Cargo.toml | 3 +- .../hir/src/{type_infer.rs => _type_infer.rs} | 0 .../{type_inference.rs => _type_inference.rs} | 0 core/hir/src/hir.rs | 6 +- core/hir/src/lib.rs | 6 +- core/hir/src/nodes.rs | 2 + core/hir/src/symbol_table.rs | 1616 +++++++++++++++++ 9 files changed, 1938 insertions(+), 7 deletions(-) create mode 100644 core/ast/src/node_type.rs rename core/hir/src/{type_infer.rs => _type_infer.rs} (100%) rename core/hir/src/{type_inference.rs => _type_inference.rs} (100%) create mode 100644 core/hir/src/symbol_table.rs diff --git a/core/ast/src/lib.rs b/core/ast/src/lib.rs index fa3abec1..8df1afac 100644 --- a/core/ast/src/lib.rs +++ b/core/ast/src/lib.rs @@ -2,5 +2,6 @@ pub mod arena; pub mod builder; pub(crate) mod enums_impl; +pub mod node_type; pub mod nodes; pub(crate) mod nodes_impl; diff --git a/core/ast/src/node_type.rs b/core/ast/src/node_type.rs new file mode 100644 index 00000000..a762bc31 --- /dev/null +++ b/core/ast/src/node_type.rs @@ -0,0 +1,311 @@ +//! AST node types for Rust type annotations. +//! +//! Defines the `NodeType` enum representing parsed type expressions such as paths, +//! references, pointers, tuples, arrays, and more. + +use std::rc::Rc; + +use crate::nodes::{ + Definition, Directive, EnumDefinition, Expression, FunctionCallExpression, FunctionDefinition, + Literal, Location, SourceFile, Statement, StructDefinition, Type, +}; +pub type RcFile = Rc; +pub type RcContract = Rc; +pub type RcFunction = Rc; +pub type RcExpression = Rc; +pub type RcFunctionCall = Rc; + +pub type RcEnum = Rc; +pub type RcStruct = Rc; + +#[derive(Clone, PartialEq, Eq, Debug, Default, serde::Serialize, serde::Deserialize)] +pub enum NodeType { + #[default] + Empty, + /// A named type or path, including any generics as represented in the token stream + Path(String), + /// A reference `&T` or `&mut T`, with explicit flag + Reference { + inner: Box, + mutable: bool, + is_explicit_reference: bool, + }, + /// A raw pointer `*const T` or `*mut T` + Ptr { inner: Box, mutable: bool }, + /// A tuple type `(T1, T2, ...)` + Tuple(Vec), + /// An array type `[T; len]`, with optional length if parseable + Array { + inner: Box, + len: Option, + }, + /// A slice type `[T]` + Slice(Box), + /// A bare function pointer `fn(a, b) -> R` + BareFn { + inputs: Vec, + output: Box, + }, + /// A generic type annotation, e.g., `Option`, `Result` + Generic { + base: Box, + args: Vec, + }, + /// A trait object type `dyn Trait1 + Trait2` + TraitObject(Vec), + /// An `impl Trait` type + ImplTrait(Vec), + Closure { + inputs: Vec, + output: Box, + }, +} + +impl NodeType { + #[must_use] + pub fn name(&self) -> String { + match self { + NodeType::Path(name) => name.clone(), + NodeType::Reference { + inner, + is_explicit_reference, + .. + } => { + if *is_explicit_reference { + format!("&{}", inner.name()) + } else { + inner.name() + } + } + NodeType::Ptr { inner, mutable } => { + let star = if *mutable { "*mut" } else { "*const" }; + format!("{} {}", star, inner.name()) + } + NodeType::Tuple(elems) => format!( + "({})", + elems + .iter() + .map(NodeType::name) + .collect::>() + .join(", ") + ), + NodeType::Array { inner, len } => format!( + "[{}; {}]", + inner.name(), + len.map_or("..".to_string(), |l| l.to_string()) + ), + NodeType::Slice(inner) => format!("[{}]", inner.name()), + NodeType::BareFn { inputs, output } => { + let mut result = inputs + .iter() + .map(NodeType::name) + .collect::>() + .join(", "); + if result.is_empty() { + result = "_".to_string(); + } + let output = if output.name().is_empty() { + "_".to_string() + } else { + output.name() + }; + format!("fn({result}) -> {output}") + } + NodeType::Closure { inputs, output } => { + let mut result = inputs + .iter() + .map(NodeType::name) + .collect::>() + .join(", "); + if result.is_empty() { + result = "_".to_string(); + } + let output = if output.name().is_empty() { + "_".to_string() + } else { + output.name() + }; + format!("{result} || -> {output}") + } + NodeType::Generic { base, args } => format!( + "{}<{}>", + base.name(), + args.iter() + .map(NodeType::name) + .collect::>() + .join(", ") + ), + NodeType::TraitObject(bounds) => format!("dyn {}", bounds.join(" + ")), + NodeType::ImplTrait(bounds) => format!("impl {}", bounds.join(" + ")), + NodeType::Empty => String::from("_"), + } + } + + #[must_use] + pub fn pure_name(&self) -> String { + match self { + NodeType::Path(name) => name.clone(), + NodeType::Reference { inner, .. } + | NodeType::Ptr { inner, .. } + | NodeType::Array { inner, len: _ } + | NodeType::Slice(inner) => inner.pure_name(), + NodeType::Tuple(elems) => format!( + "({})", + elems + .iter() + .map(NodeType::pure_name) + .collect::>() + .join(", ") + ), + NodeType::BareFn { inputs, output } => { + let mut result = inputs + .iter() + .map(NodeType::pure_name) + .collect::>() + .join(", "); + if result.is_empty() { + result = "_".to_string(); + } + let output = output.pure_name(); + format!("fn({result}) -> {output}") + } + NodeType::Closure { inputs, output } => { + let mut result = inputs + .iter() + .map(NodeType::pure_name) + .collect::>() + .join(", "); + if result.is_empty() { + result = "_".to_string(); + } + let output = output.pure_name(); + format!("{result} || -> {output}") + } + NodeType::Generic { base, args } => format!( + "{}<{}>", + base.pure_name(), + args.iter() + .map(NodeType::pure_name) + .collect::>() + .join(", ") + ), + NodeType::TraitObject(bounds) | NodeType::ImplTrait(bounds) => bounds.join(" + "), + NodeType::Empty => String::from("_"), + } + } + + #[must_use] + pub fn is_self(&self) -> bool { + match self { + NodeType::Path(name) => name.to_lowercase() == "self", + NodeType::Reference { inner, .. } + | NodeType::Ptr { inner, .. } + | NodeType::Array { inner, .. } + | NodeType::Slice(inner) => inner.is_self(), + NodeType::Tuple(elems) => elems.iter().any(NodeType::is_self), + NodeType::BareFn { inputs, output } | NodeType::Closure { inputs, output } => { + inputs.iter().any(NodeType::is_self) || output.is_self() + } + NodeType::Generic { base, args } => { + base.is_self() || args.iter().any(NodeType::is_self) + } + NodeType::TraitObject(bounds) | NodeType::ImplTrait(bounds) => { + bounds.iter().any(|b| b.to_lowercase() == "self") + } + NodeType::Empty => false, + } + } + + #[allow(clippy::assigning_clones)] + pub fn replace_path(&mut self, new_path: String) { + match self { + NodeType::Path(_) => { + *self = NodeType::Path(new_path); + } + NodeType::Reference { inner, .. } + | NodeType::Ptr { inner, .. } + | NodeType::Array { inner, .. } + | NodeType::Slice(inner) => { + inner.replace_path(new_path); + } + NodeType::Tuple(elems) => { + for elem in elems { + elem.replace_path(new_path.clone()); + } + } + NodeType::BareFn { inputs, output } | NodeType::Closure { inputs, output } => { + for input in inputs { + input.replace_path(new_path.clone()); + } + output.replace_path(new_path); + } + NodeType::Generic { base, args } => { + base.replace_path(new_path.clone()); + for arg in args { + arg.replace_path(new_path.clone()); + } + } + NodeType::TraitObject(bounds) | NodeType::ImplTrait(bounds) => { + for bound in bounds.iter_mut() { + if bound.to_lowercase() == "self" { + *bound = new_path.clone(); + } + } + } + NodeType::Empty => {} + } + } +} + +#[derive(Clone, Debug)] +pub enum NodeKind { + File(Rc), + Directive(Directive), + Definition(Definition), + Statement(Statement), + Expression(Expression), + Literal(Literal), + Type(Type), +} + +impl NodeKind { + #[must_use] + pub fn id(&self) -> u32 { + match self { + NodeKind::File(f) => f.id, + NodeKind::Definition(d) => d.id(), + NodeKind::Directive(d) => d.id(), + NodeKind::Statement(s) => s.id(), + NodeKind::Expression(e) => e.id(), + NodeKind::Literal(l) => l.id(), + NodeKind::Type(t) => t.id(), + } + } + + #[must_use] + #[allow(clippy::cast_possible_truncation)] + pub fn location(&self) -> Location { + match self { + NodeKind::File(f) => f.location().clone(), + NodeKind::Definition(d) => d.location().clone(), + NodeKind::Directive(d) => d.location(), + NodeKind::Statement(s) => s.location(), + NodeKind::Expression(e) => e.location(), + NodeKind::Literal(l) => l.location(), + NodeKind::Type(t) => t.location(), + } + } + + #[must_use] + pub fn children(&self) -> Vec { + match self { + NodeKind::File(file) => file.children(), + NodeKind::Definition(definition) => definition.children(), + NodeKind::Directive(directive) => directive.children(), + NodeKind::Statement(statement) => statement.children(), + NodeKind::Expression(expression) => expression.children(), + NodeKind::Literal(literal) => literal.children(), + NodeKind::Type(ty) => ty.children(), + } + } +} diff --git a/core/hir/Cargo.toml b/core/hir/Cargo.toml index 2e146811..de60397a 100644 --- a/core/hir/Cargo.toml +++ b/core/hir/Cargo.toml @@ -8,4 +8,5 @@ repository = { workspace = true } [dependencies] inference-ast.workspace = true -anyhow.workspace = true \ No newline at end of file +anyhow.workspace = true +serde.workspace = true \ No newline at end of file diff --git a/core/hir/src/type_infer.rs b/core/hir/src/_type_infer.rs similarity index 100% rename from core/hir/src/type_infer.rs rename to core/hir/src/_type_infer.rs diff --git a/core/hir/src/type_inference.rs b/core/hir/src/_type_inference.rs similarity index 100% rename from core/hir/src/type_inference.rs rename to core/hir/src/_type_inference.rs diff --git a/core/hir/src/hir.rs b/core/hir/src/hir.rs index 8fc4316d..2d6327c3 100644 --- a/core/hir/src/hir.rs +++ b/core/hir/src/hir.rs @@ -1,10 +1,10 @@ -use crate::{arena::Arena, type_infer::TypeChecker}; +use crate::{arena::Arena, symbol_table::SymbolTable}; use inference_ast::arena::Arena as AstArena; #[derive(Clone, Default)] pub struct Hir { pub arena: Arena, - pub type_checker: TypeChecker, + pub symbol_table: SymbolTable, } impl Hir { @@ -12,7 +12,7 @@ impl Hir { pub fn new(arena: AstArena) -> Self { Self { arena: Arena::default(), - type_checker: TypeChecker::default(), + symbol_table: SymbolTable::default(), } } } diff --git a/core/hir/src/lib.rs b/core/hir/src/lib.rs index c6d16d5f..9cdff681 100644 --- a/core/hir/src/lib.rs +++ b/core/hir/src/lib.rs @@ -1,8 +1,8 @@ #![warn(clippy::pedantic)] pub mod hir; -pub mod nodes; +mod nodes; mod nodes_impl; -pub mod type_infer; +mod symbol_table; // mod type_inference; -pub mod arena; +mod arena; pub mod type_info; diff --git a/core/hir/src/nodes.rs b/core/hir/src/nodes.rs index e7db761a..29adac5e 100644 --- a/core/hir/src/nodes.rs +++ b/core/hir/src/nodes.rs @@ -6,6 +6,7 @@ pub enum Directive { Use(Rc), } +#[derive(Clone, Debug)] pub enum Definition { Spec(Rc), Struct(Rc), @@ -105,6 +106,7 @@ pub struct SourceFile { pub definitions: Vec, } +#[derive(Debug)] pub struct UseDirective { pub imported_types: Option>, pub segments: Option>, diff --git a/core/hir/src/symbol_table.rs b/core/hir/src/symbol_table.rs new file mode 100644 index 00000000..2aca0d49 --- /dev/null +++ b/core/hir/src/symbol_table.rs @@ -0,0 +1,1616 @@ +//! Symbol table implementation for Inference code. +//! +//! Defines scopes, definitions, and name resolution logic, tracking symbols +//! (types, functions, variables) across modules. +// use serde::{Deserialize, Serialize}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; + +use inference_ast::nodes::{ + Block, Definition, Expression, FunctionDefinition, Literal, OperatorKind, Statement, + UseDirective, +}; + +pub(crate) type ScopeRef = Rc>; + +#[derive(Clone)] +pub(crate) enum DefinitionRef { + Ref(String, Definition), + QualifiedName(String), +} + +impl DefinitionRef { + pub(crate) fn name(&self) -> String { + match self { + DefinitionRef::QualifiedName(name) | DefinitionRef::Ref(name, _) => name.clone(), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct Scope { + pub(crate) id: u32, + pub(crate) name: String, + pub(crate) parent: Option, + pub(crate) children: Vec, + pub(crate) imports: Vec>, + import_aliases: HashMap, + pub(crate) definitions: HashMap, + variables: HashMap, + /// Structs or enums and their methods + methods: HashMap>>, +} + +impl Scope { + pub(crate) fn new(id: u32, name: String, parent: Option) -> ScopeRef { + Rc::new(RefCell::new(Scope { + id, + name, + parent, + children: Vec::new(), + imports: Vec::new(), + import_aliases: HashMap::new(), + definitions: HashMap::new(), + variables: HashMap::new(), + methods: HashMap::new(), + })) + } + + pub fn visible_child(&self, ident: &str) -> Option { + if let Some(sub) = self + .children + .iter() + .find(|s| s.borrow().name.rsplit("::").next() == Some(ident)) + { + return Some(DefOrScopeRef::Module(sub.clone())); + } + + if let Some(def) = self.definitions.get(ident) { + return Some(DefOrScopeRef::Definition(def.clone())); + } + + for u in &self.imports { + for (key, def_opt) in u.target.borrow().iter() { + let bound = if let Some((_, alias)) = key.split_once('%') { + alias + } else { + key.rsplit("::").next().unwrap_or(key) + }; + if bound == ident { + if let Some(def) = def_opt { + return Some(DefOrScopeRef::Definition(def.clone())); + } + } + } + } + + None + } + + fn get_crate_name(&self) -> String { + if let Some(parent) = &self.parent { + parent.borrow().get_crate_name() + } else { + self.name.clone() + } + } + + fn relative_to_absolute_path(&self, path: &str) -> String { + if path.starts_with("crate") { + let crate_name = self.get_crate_name(); + return path.replace("crate", &crate_name); + } + path.to_string() + } + + fn resolve_self(&self) -> Option { + if let Some((_, ty)) = self.variables.get("self") { + return self.lookup_def(&ty.name()); + } + if let Some(parent) = &self.parent { + return parent.borrow().resolve_self(); + } + None + } + + fn lookup_def(&self, name: &str) -> Option { + let mut name: &str = name; + if name.contains("::") { + name = name.split("::").last().unwrap_or(name); + } + if self.import_aliases.contains_key(name) { + name = self.import_aliases.get(name).unwrap(); + } + if let Some(def) = self.try_get_definition(name) { + if self + .definitions + .keys() + .find_map(|def_name| { + if def_name == name { + Some(DefinitionRef::Ref(def_name.clone(), def.clone())) + } else { + None + } + }) + .is_some() + { + let q_name = format!("{}::{}", self.name, def.name()); + return Some(DefinitionRef::Ref(q_name, def)); + } + return Some(DefinitionRef::Ref(name.to_string(), def.clone())); + } + for import in &self.imports { + for it in &import.imported_types { + if it == name || it.ends_with(name) { + if let Some(Some(def)) = import.target.borrow().get(it) { + return Some(DefinitionRef::Ref(it.clone(), def.clone())); + } + return Some(DefinitionRef::QualifiedName(it.clone())); + } + } + } + if let Some(parent) = &self.parent { + return parent.borrow().lookup_def(name); + } + None + } + + fn insert_var(&mut self, name: String, id: u32, ty: NodeType) { + self.variables.insert(name, (id, ty)); + } + + fn lookup_symbol(&self, name: &str) -> Option { + if let Some(ty) = self.variables.get(name) { + return Some(ty.1.clone()); + } + if let Some(parent) = &self.parent { + return parent.borrow().lookup_symbol(name); + } + None + } + + pub(crate) fn try_get_definition(&self, name: &str) -> Option { + self.definitions.get(name).cloned() + } + + fn qualify_definition_name(&self, name: &str) -> Option { + if let Some(res) = self.definitions.keys().find_map(|def_name| { + if def_name == name { + Some(def_name.clone()) + } else { + None + } + }) { + return Some(res); + } + if let Some(parent) = &self.parent { + return parent.borrow().qualify_definition_name(name); + } + None + } + + pub fn crate_root_id(&self) -> u32 { + if let Some(parent) = &self.parent { + parent.borrow().crate_root_id() + } else { + self.id + } + } + + pub fn crate_root(&self) -> Option { + if let Some(parent) = &self.parent { + if parent.borrow().parent.is_none() { + Some(parent.clone()) + } else { + parent.borrow().crate_root() + } + } else { + None + } + } + + fn functions(&self) -> impl Iterator> { + self.definitions.values().filter_map(|def| { + if let Definition::Function(f) = def { + Some(f) + } else { + None + } + }) + } + + fn methods(&self) -> impl Iterator> { + self.methods.values().flat_map(|methods| methods.iter()) + } +} + +// #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +// struct DefinitionName { +// name: String, +// qualified_name: String, +// } + +pub(crate) enum DefOrScopeRef { + Module(ScopeRef), + Definition(Definition), +} +impl DefOrScopeRef { + pub(crate) fn as_module(&self) -> Option>> { + match self { + DefOrScopeRef::Module(scope) => Some(scope.clone()), + DefOrScopeRef::Definition(_) => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SymbolTable { + pub(crate) scopes: HashMap, + mod_scopes: HashMap, + defs: HashMap<(u32, String), Definition>, +} + +impl Default for SymbolTable { + fn default() -> Self { + SymbolTable::new() + } +} + +impl SymbolTable { + pub(crate) fn new() -> Self { + SymbolTable { + scopes: HashMap::new(), + mod_scopes: HashMap::new(), + defs: HashMap::new(), + } + } + + #[allow(clippy::needless_pass_by_value)] + pub(crate) fn insert_scope(&mut self, scope: ScopeRef) { + self.scopes.insert(scope.borrow().id, scope.clone()); + self.mod_scopes + .insert(scope.borrow().name.clone(), scope.clone()); + } + + pub(crate) fn insert_def( + &mut self, + scope_id: u32, + qualified_name: String, + def: Definition, + ) -> Option { + self.defs.insert((scope_id, qualified_name), def) + } + + fn get_scope_by_def_id(&self, id: u32) -> Option { + for ((s_id, _), v) in &self.defs { + if v.id() == id { + return self.scopes.get(s_id).cloned(); + } + } + None + } + + pub(crate) fn build_symbol_tables(&mut self) { + let scopes = self.scopes.values().cloned().collect::>(); + for scope in scopes { + let functions = scope + .borrow() + .definitions + .iter() + .filter_map(|(_, def)| { + if let Definition::Function(f) = def { + Some(f.clone()) + } else { + None + } + }) + .collect::>(); + for mut function in functions { + self.build_function_symbol_table(&scope, &mut function); + } + let scope_borrow = scope.borrow(); + let methods = scope_borrow.methods.values().cloned().collect::>(); + drop(scope_borrow); + for methods_group in methods { + for method in methods_group { + self.build_function_symbol_table(&scope, &mut method.clone()); + } + } + let scope_borrow = scope.borrow(); + let functions = scope_borrow.functions().collect::>(); + for function in functions { + self.build_function_symbol_table(&scope, &mut function.clone()); + } + } + } + + fn build_function_symbol_table(&mut self, scope: &ScopeRef, f: &mut Rc) { + let self_type = self.find_self_type_for_method(f.clone()); + let fn_scope = self + .scopes + .get(&f.id) + .unwrap_or_else(|| { + panic!( + "Function scope with id {} not found for function {}", + f.id, f.name + ) + }) + .clone(); + for p in &f.parameters { + let mut param_ty = NodeType::from_string(&p.type_name); + if param_ty.is_self() { + if let Some(self_ty_name) = &self_type { + param_ty.replace_path(self_ty_name.clone()); + } + } + let param_pure_type_name = param_ty.pure_name(); + + if let Some(DefinitionRef::Ref(qname, _)) = + scope.borrow().lookup_def(¶m_pure_type_name) + { + param_ty.replace_path(qname.clone()); + } + + fn_scope + .borrow_mut() + .insert_var(p.name.clone(), p.id, param_ty); + } + let ty_node_name = &f.returns.borrow().pure_name(); + if let Some(def_ref) = scope.borrow().lookup_def(ty_node_name) { + f.returns.borrow_mut().replace_path(def_ref.name()); + } + if let Some(block) = &f.body { + for stmt in &block.statements { + process_statement(stmt, &fn_scope, self); + } + } + } + + #[must_use] + pub fn lookup_symbol_in_scope(&self, scope_id: u32, name: &str) -> Option { + if self.scopes.contains_key(&scope_id) { + return self + .scopes + .get(&scope_id) + .unwrap() + .borrow() + .lookup_symbol(name); + } + let mut stack: Vec = self.scopes.values().cloned().collect(); + while let Some(scope) = stack.pop() { + if scope.borrow().id == scope_id { + return scope.borrow().lookup_symbol(name); + } + for child in &scope.borrow().children { + stack.push(child.clone()); + } + } + None + } + + #[must_use] + pub fn infer_expr_type(&self, scope_id: u32, expr: &Expression) -> NodeType { + if let Some(scope) = self.scopes.get(&scope_id) { + return infer_expr_type(expr, scope, self); + } + let mut stack: Vec = self.scopes.values().cloned().collect(); + while let Some(scope) = stack.pop() { + if scope.borrow().id == scope_id { + return infer_expr_type(expr, &scope, self); + } + for child in &scope.borrow().children { + stack.push(child.clone()); + } + } + NodeType::Empty + } + + #[allow(clippy::needless_pass_by_value)] + pub(crate) fn find_self_type_for_method( + &self, + function: Rc, + ) -> Option { + fn find_in_scope(scope: &ScopeRef, function_name: &str) -> Option { + let scope_b = scope.borrow(); + for (type_name, methods) in &scope_b.methods { + if methods.iter().any(|m| m.name == function_name) { + return Some(type_name.clone()); + } + } + for child in &scope_b.children { + if let Some(name) = find_in_scope(child, function_name) { + return Some(name); + } + } + None + } + for scope in self.mod_scopes.values() { + if let Some(name) = find_in_scope(scope, &function.name) { + return Some(name); + } + } + None + } + + pub(crate) fn get_function_by_name(&self, scope_id: u32, name: &str) -> Option { + if let Some(scope) = &self.scopes.get(&scope_id) { + if let Some(def_ref) = scope.borrow().lookup_def(name) { + match def_ref { + DefinitionRef::Ref(_, def) => return Some(def), + DefinitionRef::QualifiedName(q_name) => { + for d in &self.defs { + if d.0.1 == q_name { + return Some(d.1.clone()); + } + } + } + } + } + } + None + } + + #[allow(unused_variables, clippy::too_many_lines)] + pub(crate) fn lookup_symbol_origin( + &self, + scope_id: u32, + symbol: &str, + ) -> Option { + fn find_stmt(block: &Block, id: u32) -> Option { + for stmt in &block.statements { + if stmt.id() == id { + return Some(stmt.clone()); + } + match stmt { + Statement::Block(inner) => { + if let Some(s) = find_stmt(inner, id) { + return Some(s); + } + } + Statement::Expression(expr) => match expr { + Expression::If(if_expr) => { + if let Some(s) = find_stmt(&if_expr.then_branch, id) { + return Some(s); + } + if let Some(Expression::EBlock(inner)) = &if_expr.else_branch { + if let Some(s) = find_stmt(&inner.block, id) { + return Some(s); + } + } + } + Expression::ForLoop(for_loop) => { + if let Some(s) = find_stmt(&for_loop.block, id) { + return Some(s); + } + } + Expression::Loop(loop_stmt) => { + if let Some(s) = find_stmt(&loop_stmt.block, id) { + return Some(s); + } + } + Expression::TryBlock(try_block) => { + if let Some(s) = find_stmt(&try_block.block, id) { + return Some(s); + } + } + Expression::EBlock(inner) => { + if let Some(s) = find_stmt(&inner.block, id) { + return Some(s); + } + } + _ => {} + }, + _ => {} + } + } + None + } + fn find_call_sites( + block: &Block, + name: &str, + param_idx: usize, + caller_id: u32, + calls: &mut Vec<(u32, Expression)>, + ) { + for stmt in &block.statements { + if let Statement::Expression(Expression::FunctionCall(fc)) = stmt { + if fc.function_name == name { + if let Some(arg) = fc.parameters.get(param_idx) { + calls.push((caller_id, arg.clone())); + } + } + } + if let Statement::Block(inner) = stmt { + find_call_sites(inner, name, param_idx, caller_id, calls); + } + if let Statement::Expression(Expression::If(if_expr)) = stmt { + find_call_sites(&if_expr.then_branch, name, param_idx, caller_id, calls); + if let Some(Expression::EBlock(inner)) = &if_expr.else_branch { + find_call_sites(&inner.block, name, param_idx, caller_id, calls); + } + } + if let Statement::Expression(Expression::ForLoop(for_loop)) = stmt { + find_call_sites(&for_loop.block, name, param_idx, caller_id, calls); + } + if let Statement::Expression(Expression::Loop(loop_stmt)) = stmt { + find_call_sites(&loop_stmt.block, name, param_idx, caller_id, calls); + } + } + } + if let Some(scope_rc) = self.scopes.get(&scope_id) { + let scope = scope_rc.borrow(); + if let Some((var_id, _)) = scope.variables.get(symbol) { + if let Some(parent_rc) = &scope.parent { + let parent = parent_rc.borrow(); + if let Some(func_name) = scope.name.rsplit("::").next() { + if let Some(DefinitionRef::Ref(_, Definition::Function(func_def))) = + parent.lookup_def(func_name) + { + if let Some((idx, param)) = func_def + .parameters + .iter() + .enumerate() + .find(|(_, p)| p.id == *var_id) + { + let mut calls = Vec::new(); + for def in self.defs.values() { + if let Definition::Function(caller_def) = def { + if caller_def.id != func_def.id { + if let Some(body) = &caller_def.body { + find_call_sites( + body, + func_name, + idx, + caller_def.id, + &mut calls, + ); + } + } + } + } + if calls.len() == 1 { + let (caller_id, arg_expr) = &calls[0]; + if let Expression::Identifier(id) = arg_expr { + return self.lookup_symbol_origin(*caller_id, &id.name); + } + } + return Some(crate::node_type::NodeKind::Misc(Misc::FnParameter( + param.clone(), + ))); + } + if let Some(body) = &func_def.body { + if let Some(stmt) = find_stmt(body, *var_id) { + return Some(crate::node_type::NodeKind::Statement(stmt)); + } + } + } + } + } + } + } + + let mut current = self.scopes.get(&scope_id).cloned(); + while let Some(scope_rc) = current { + let scope = scope_rc.borrow(); + if let Some(def_ref) = scope.lookup_def(symbol) { + match def_ref { + DefinitionRef::Ref(_, def) => { + return Some(crate::node_type::NodeKind::Definition(def.clone())); + } + DefinitionRef::QualifiedName(qn) => { + for ((_, name), def) in &self.defs { + if *name == qn { + return Some(crate::node_type::NodeKind::Definition(def.clone())); + } + } + } + } + } + current = scope.parent.as_ref().map(Rc::clone); + } + None + } +} + +pub(crate) fn process_definition( + parent_scope: &ScopeRef, + def: Definition, + table: &mut SymbolTable, +) { + let parent_path = parent_scope.borrow().name.clone(); + let name = def.name(); + let qualified = format!("{parent_path}::{name}"); + parent_scope + .borrow_mut() + .definitions + .insert(name.clone(), def.clone()); + table.insert_def(parent_scope.borrow().id, qualified.clone(), def.clone()); + match def { + Definition::Struct(s) => { + parent_scope + .borrow_mut() + .methods + .insert(s.name.clone(), Vec::new()); + } + Definition::Enum(e) => { + parent_scope + .borrow_mut() + .methods + .insert(e.name.clone(), Vec::new()); + } + Definition::Function(f) => { + let fn_scope = Scope::new(f.id, qualified, Some(parent_scope.clone())); + table.insert_scope(fn_scope.clone()); + } + Definition::Implementation(i) => { + for constant in &i.constants { + let constant_name = constant.name.clone(); + let constant_def = Definition::Const(constant.clone()); + parent_scope + .borrow_mut() + .definitions + .insert(constant_name, constant_def); + } + for ta in &i.type_aliases { + let alias_name = ta.name.clone(); + let alias_def = Definition::TypeAlias(ta.clone()); + parent_scope + .borrow_mut() + .definitions + .insert(alias_name, alias_def); + } + if let Type::Typename(target) = &i.for_type { + let def_name_op = parent_scope.borrow().qualify_definition_name(&target.name); + if let Some(def_name) = def_name_op { + if let Some(method_list) = parent_scope.borrow_mut().methods.get_mut(&def_name) + { + for func in &i.functions { + method_list.push(func.clone()); + } + } + for func in &i.functions { + process_definition( + &parent_scope.clone(), + Definition::Function(func.clone()), + table, + ); + } + } + } + } + _ => {} + } +} + +pub(crate) fn fixpoint_resolver(table: &mut SymbolTable, extern_prelude: &mut ExternPrelude) { + loop { + let mut progress = false; + for scope in table.scopes.values() { + for rc_use in &scope.borrow().imports { + if rc_use.is_resolved() { + continue; + } + let imported_types = rc_use.imported_types.clone(); + for import_path in imported_types { + let (orig_path, _) = import_path + .split_once('%') + .map(|(orig, alias)| (orig.to_string(), Some(alias.to_string()))) + .unwrap_or((import_path.clone(), None)); + let (head, tail) = match orig_path.split_once("::") { + Some((h, t)) => (h.to_string(), t.to_string()), + None => (orig_path.clone(), String::new()), + }; + let start_scope = match head.as_str() { + "crate" => scope.borrow().crate_root(), + "self" => Some(scope.clone()), + "super" => scope.borrow().parent.clone(), + other => { + if let Some(ext) = extern_prelude.get(other) { + Some(ext.root_scope.clone()) + } else { + scope + .borrow() + .visible_child(other) + .and_then(|d| d.as_module()) + .or_else(|| { + scope.borrow().crate_root().and_then(|root| { + root.borrow() + .visible_child(other) + .and_then(|d| d.as_module()) + }) + }) + } + } + }; + if let Some(start) = start_scope { + if let Some(def) = walk_segments(start, &tail, scope) { + if rc_use.target.borrow().contains_key(&import_path) + && rc_use.target.borrow().get(&import_path).unwrap().is_some() + { + continue; + } + rc_use.insert_target(import_path, Some(def)); + progress = true; + } + } + } + } + } + if !progress { + break; + } + } + table.build_symbol_tables(); +} + +#[allow(clippy::assigning_clones)] +fn scope_path(scope: &ScopeRef) -> Vec { + let mut segments = Vec::new(); + let mut cur = Some(scope.clone()); + while let Some(sc) = cur { + segments.push(sc.borrow().name.clone()); + cur = sc.borrow().parent.clone(); + } + segments.reverse(); + segments +} + +impl Visibility { + pub(crate) fn is_visible_from(&self, from_scope: &ScopeRef, owner_scope: &ScopeRef) -> bool { + if *self == Visibility::Public { + return true; + } + + let same_crate = + from_scope.borrow().crate_root_id() == owner_scope.borrow().crate_root_id(); + + match self { + Visibility::Private | Visibility::Inherited => { + let owner_path = scope_path(owner_scope); + let from_path = scope_path(from_scope); + from_path.starts_with(&owner_path) + } + Visibility::PubCrate => same_crate, + Visibility::PubSuper => { + if let Some(parent) = owner_scope.borrow().parent.clone() { + let super_path = scope_path(&parent); + let from_path = scope_path(from_scope); + from_path.starts_with(&super_path) + } else { + // owner is crate root -> `pub(super)` degenerates to private + false + } + } + Visibility::PubIn(path) => { + let target_mod = path + .split("::") + .filter(|s| !s.is_empty()) + .map(str::to_owned) + .collect::>(); + let from_path = scope_path(from_scope); + from_path.starts_with(&target_mod) + } + + Visibility::Public => unreachable!("handled above {:?}", self), + } + } +} + +fn walk_segments(mut scope: ScopeRef, path: &str, from_crate: &ScopeRef) -> Option { + if path.is_empty() { + return None; + } + let mut segments = path.split("::").filter(|s| !s.is_empty()).peekable(); + while let Some(seg) = segments.next() { + let is_last = segments.peek().is_none(); + let child = scope.borrow().visible_child(seg)?; + match child { + DefOrScopeRef::Module(sub_scope) => { + if is_last { + return None; + } + scope = sub_scope; + } + DefOrScopeRef::Definition(def) => { + if !def.visibility().is_visible_from(from_crate, &scope) { + return None; + } + if is_last { + return Some(def); + } + return None; + } + } + } + None +} + +#[allow(clippy::too_many_lines)] +fn process_statement(stmt: &Statement, scope: &ScopeRef, table: &mut SymbolTable) { + match stmt { + Statement::VariableDefinition(let_stmt) => { + let mut vty = if let Some(init) = &let_stmt.initial_value { + let ty = infer_expr_type(init, scope, table); + if ty != NodeType::Empty { + ty + } else if let Some((_, ty_str)) = let_stmt.pattern.kind.split_once(':') { + parse_str::(ty_str.trim()) + .map(|ty| NodeType::from_syn_item(&ty)) + .unwrap_or(NodeType::Empty) + } else { + NodeType::Empty + } + } else { + NodeType::Empty + }; + if let Some(def) = scope.borrow().lookup_def(&vty.name()) { + match def { + DefinitionRef::Ref(_, _) => {} + DefinitionRef::QualifiedName(qualified_name) => { + vty = NodeType::Path(qualified_name); + } + } + } + scope + .borrow_mut() + .insert_var(let_stmt.name.clone(), let_stmt.id, vty); + } + Statement::Expression(Expression::If(if_expr)) => { + for stmt in &if_expr.then_branch.statements { + process_statement(stmt, scope, table); + } + } + Statement::Block(block) => { + for stmt in &block.statements { + process_statement(stmt, scope, table); + } + } + Statement::Expression(Expression::EBlock(e_block)) => { + for stmt in &e_block.block.statements { + process_statement(stmt, scope, table); + } + } + _ => {} + } +} + +#[allow(clippy::too_many_lines)] +fn infer_expr_type(expr: &Expression, scope: &ScopeRef, table: &SymbolTable) -> NodeType { + match expr { + Expression::Identifier(id) => { + if id.name.contains("::") { + let (module, rest) = id.name.split_once("::").unwrap(); + if let Some(mod_scope) = table.mod_scopes.get(module) { + if let Some(def) = mod_scope.borrow().lookup_def(rest) { + let ty = match def { + DefinitionRef::Ref(_, Definition::Const(c)) => c.type_.to_type_node(), + DefinitionRef::Ref(_, Definition::Static(s)) => s.ty.to_type_node(), + DefinitionRef::Ref(_, Definition::Function(f)) => { + f.returns.borrow().clone() + } + DefinitionRef::Ref(_, Definition::AssocType(t)) => t.ty.to_type_node(), + DefinitionRef::Ref( + _, + Definition::Struct(s) | Definition::Contract(s), + ) => NodeType::Path(s.name.clone()), + DefinitionRef::Ref(_, Definition::Enum(e)) => { + NodeType::Path(e.name.clone()) + } + DefinitionRef::Ref(_, Definition::Union(u)) => { + NodeType::Path(u.name.clone()) + } + DefinitionRef::Ref(_, Definition::Module(m)) => { + NodeType::Path(m.name.clone()) + } + DefinitionRef::Ref(_, Definition::TraitAlias(ta)) => { + NodeType::Path(ta.name.clone()) + } + _ => return NodeType::Empty, + }; + return ty; + } + } + } + if let Some(v) = scope.borrow().lookup_symbol(&id.name) { + return v; + } + if let Some(def) = scope.borrow().lookup_def(&id.name) { + let ty_node = match def { + DefinitionRef::Ref(_, Definition::Const(c)) => c.type_.to_type_node(), + DefinitionRef::Ref(_, Definition::Static(s)) => s.ty.to_type_node(), + DefinitionRef::Ref(_, Definition::Function(f)) => f.returns.borrow().clone(), + DefinitionRef::Ref(_, Definition::AssocType(t)) => t.ty.to_type_node(), + DefinitionRef::Ref(_, Definition::Struct(s) | Definition::Contract(s)) => { + NodeType::Path(s.name.clone()) + } + DefinitionRef::Ref(_, Definition::Enum(e)) => NodeType::Path(e.name.clone()), + DefinitionRef::Ref(_, Definition::Union(u)) => NodeType::Path(u.name.clone()), + DefinitionRef::Ref(_, Definition::Module(m)) => NodeType::Path(m.name.clone()), + DefinitionRef::Ref(_, Definition::TraitAlias(ta)) => { + NodeType::Path(ta.name.clone()) + } + _ => return NodeType::Empty, + }; + return ty_node; + } + NodeType::Empty + } + Expression::Literal(lit_expr) => match &lit_expr.value { + Literal::Bool(_) => NodeType::Path("bool".to_string()), + Literal::Byte(_) => NodeType::Path("u8".to_string()), + Literal::Char(_) => NodeType::Path("char".to_string()), + Literal::Int(_) => NodeType::Path("i32".to_string()), + Literal::Float(_) => NodeType::Path("f64".to_string()), + Literal::String(_) => NodeType::Reference { + inner: Box::new(NodeType::Path("str".to_string())), + mutable: false, + is_explicit_reference: true, + }, + Literal::BString(_) => NodeType::Reference { + inner: Box::new(NodeType::Path("[u8]".to_string())), + mutable: false, + is_explicit_reference: true, + }, + Literal::CString(cs) => NodeType::Ptr { + inner: Box::new(NodeType::Path("c_char".to_string())), + mutable: cs.value.starts_with("*const"), + }, + }, + Expression::Binary(bin) => { + match bin.operator { + OperatorKind::Add + | OperatorKind::Sub + | OperatorKind::Mul + | OperatorKind::Div + | OperatorKind::Mod + | OperatorKind::BitXor + | OperatorKind::BitAnd + | OperatorKind::BitOr + | OperatorKind::Shl + | OperatorKind::Shr => { + let left_ty = infer_expr_type(&bin.left, scope, table); + let right_ty = infer_expr_type(&bin.right, scope, table); + if left_ty == right_ty { + return left_ty; + } + right_ty //Won't fix, fallback + } + OperatorKind::And + | OperatorKind::Or + | OperatorKind::Eq + | OperatorKind::Ne + | OperatorKind::Lt + | OperatorKind::Le + | OperatorKind::Ge + | OperatorKind::Gt => NodeType::Path("bool".to_string()), + OperatorKind::AddAssign + | OperatorKind::SubAssign + | OperatorKind::MulAssign + | OperatorKind::DivAssign + | OperatorKind::ModAssign + | OperatorKind::BitXorAssign + | OperatorKind::BitAndAssign + | OperatorKind::BitOrAssign + | OperatorKind::ShlAssign + | OperatorKind::ShrAssign => infer_expr_type(&bin.left, scope, table), + } + } + Expression::PrefixUnary(u) => infer_expr_type(&u.expression, scope, table), + Expression::FunctionCall(fc) => { + for def in scope.borrow().functions().chain(scope.borrow().methods()) { + if def.name == fc.function_name { + return def.returns.borrow().clone(); + } + } + if let Expression::Identifier(base) = &fc.expression { + let name = base.name.clone(); + if name == "Some" { + if let Some(arg) = &fc.parameters.first() { + let ty = infer_expr_type(arg, scope, table); + return NodeType::Path(format!("Option<{}>", ty.name())); + } + return NodeType::Path("Option<_>".to_string()); + } + if name == "None" { + return NodeType::Path("Option<_>".to_string()); + } + if name == "Ok" { + if let Some(arg) = &fc.parameters.first() { + let ty = infer_expr_type(arg, scope, table); + return NodeType::Path(format!("Result<{}, _>", ty.name())); + } + return NodeType::Path("Result<_, _>".to_string()); + } + if name.contains("::") { + let (module, rest) = name.split_once("::").unwrap(); + let module = module.trim(); + let rest = rest.trim(); + if let Some(mod_scope) = table.mod_scopes.get(module) { + for def in mod_scope + .borrow() + .functions() + .chain(mod_scope.borrow().methods()) + { + if def.name == rest { + return def.returns.borrow().clone(); + } + } + } + if ["Vec", "HashMap"].contains(&module) { + return NodeType::Reference { + inner: Box::new(NodeType::Path(module.to_string())), + mutable: false, + is_explicit_reference: false, + }; + } + } else if let Some(v) = scope.borrow().lookup_symbol(&name) { + return v; + } + } + NodeType::Empty + } + Expression::MethodCall(mc) => { + let base_ty = infer_expr_type(&mc.base, scope, table); + if base_ty == NodeType::Empty { + return NodeType::Empty; + } + let type_name_owned = base_ty.pure_name().clone(); + let type_name = type_name_owned.as_str(); + let base_def = if let Some(base_def_ref) = scope.borrow().lookup_def(type_name) { + match base_def_ref { + DefinitionRef::Ref(_, def) => Some(def), + DefinitionRef::QualifiedName(_) => None, + } + } else if let Some(d) = table.defs.iter().find(|((_, n), _)| n == type_name) { + Some(d.1.clone()) + } else { + return NodeType::Empty; + }; + if let Some(def) = base_def { + if let Some(def_scope) = table.get_scope_by_def_id(def.id()) { + if let Some(method) = + def_scope + .borrow() + .methods + .get(&def.name()) + .and_then(|methods| { + methods.iter().find(|f| f.name == mc.method_name).cloned() + }) + { + let ty_node_name = method.returns.borrow().pure_name(); + if let Some(def) = def_scope.borrow().lookup_def(&ty_node_name) { + match def { + DefinitionRef::Ref(qualified_name, _) => { + *method.returns.borrow_mut() = NodeType::Path( + def_scope + .borrow() + .relative_to_absolute_path(&qualified_name), + ); + } + DefinitionRef::QualifiedName(qualified_name) => { + *method.returns.borrow_mut() = NodeType::Path(qualified_name); + } + } + } + return method.returns.borrow().clone(); + } + } + } + NodeType::Empty + } + Expression::MemberAccess(ma) => { + let base_ty = infer_expr_type(&ma.base, scope, table); + if let Some(def_ref) = scope.borrow().lookup_def(&base_ty.pure_name()) { + match def_ref { + DefinitionRef::Ref(_, Definition::Struct(s) | Definition::Contract(s)) => { + for (field, fty) in &s.fields { + if &ma.member_name == field { + return fty.to_type_node(); + } + } + } + DefinitionRef::Ref(_, Definition::Enum(e)) => { + for variant in &e.variants { + if &ma.member_name == variant { + return NodeType::Path(base_ty.name().clone()); + } + } + } + DefinitionRef::QualifiedName(qn) => { + let def_ref_o = scope.borrow().lookup_def(&qn); + if let Some(def) = def_ref_o { + if let DefinitionRef::Ref( + _, + Definition::Struct(s) | Definition::Contract(s), + ) = &def + { + for (field, fty) in &s.fields { + if &ma.member_name == field { + return fty.to_type_node(); + } + } + } + if let DefinitionRef::Ref(_, Definition::Enum(e)) = def { + for variant in &e.variants { + if &ma.member_name == variant { + return NodeType::Path(base_ty.name().clone()); + } + } + } + } else { + return NodeType::Path(qn.clone()); + } + } + DefinitionRef::Ref(_, _) => {} + } + } + + NodeType::Empty + } + Expression::Return(r) => { + if let Some(e) = &r.expression { + infer_expr_type(e, scope, table) + } else { + NodeType::Tuple(Vec::new()) + } + } + Expression::Cast(c) => { + let mut ty_node = c.target_type.to_type_node(); + if ty_node.is_self() { + if let Some(self_ty) = scope.borrow().resolve_self() { + ty_node = match ty_node { + NodeType::Path(ref name) if name == "Self" => { + NodeType::Path(self_ty.name()) + } + NodeType::Reference { + inner: old_inner, + mutable, + is_explicit_reference, + } if old_inner.name() == "Self" => NodeType::Reference { + inner: Box::new(NodeType::Path(self_ty.name())), + mutable, + is_explicit_reference, + }, + NodeType::Ptr { inner, mutable } if inner.name() == "Self" => { + NodeType::Ptr { + inner: Box::new(NodeType::Path(self_ty.name())), + mutable, + } + } + other => other, + }; + } + } + ty_node + } + Expression::Tuple(t) => { + let mut types = Vec::new(); + for e in &t.elements { + types.push(infer_expr_type(e, scope, table)); + } + NodeType::Tuple(types) + } + _ => NodeType::Empty, + } +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use crate::{Codebase, OpenState, directive::Directive, node_type::NodeKind}; + +// #[must_use] +// #[allow(clippy::missing_panics_doc)] +// pub(crate) fn resolve_path( +// table: &SymbolTable, +// scope_id: u32, +// path: &str, +// ) -> Option { +// let mut parts = path +// .split("::") +// .filter(|s| !s.is_empty()) +// .collect::>(); +// if parts.is_empty() { +// return None; +// } +// let name = parts.pop().unwrap(); +// if parts.is_empty() { +// if let Some(scope) = table +// .mod_scopes +// .iter() +// .find(|(_, v)| v.borrow().id == scope_id) +// .map(|(_, v)| v) +// { +// return scope.borrow().lookup_def(name); +// } +// } +// let module_path = parts.join("::"); +// for (mod_scope_name, mod_scope) in &table.mod_scopes { +// if *mod_scope_name == module_path { +// for (def_name, def) in &mod_scope.borrow().definitions { +// if def_name == name { +// return Some(DefinitionRef::Ref(def_name.clone(), def.clone())); +// } +// } +// } +// } +// None +// } + +// fn build_table_and_uses(src: &str) -> (SymbolTable, Vec>) { +// let mut cb = Codebase::::default(); +// let content = src.to_string(); +// let mut data = HashMap::new(); +// data.insert( +// format!("test{0}test.rs", std::path::MAIN_SEPARATOR), +// content, +// ); +// let sealed = cb.build_api(&data).unwrap(); +// let table = sealed.symbol_table.clone(); +// let file = sealed.files().next().unwrap().clone(); +// let uses = file +// .children +// .borrow() +// .iter() +// .filter_map(|node| { +// if let NodeKind::Directive(Directive::Use(u)) = node { +// Some(u.clone()) +// } else { +// None +// } +// }) +// .collect::>(); +// (table, uses) +// } + +// #[test] +// fn test_resolve_path_basic() { +// let src = r" +// mod a { +// pub struct S; +// pub mod b { +// pub enum E { A } +// } +// } +// "; +// let (table, _) = build_table_and_uses(src); +// let scope_id = table.mod_scopes.iter().next().unwrap().1.borrow().id; +// assert!( +// resolve_path( +// &table, +// scope_id, +// "soroban_security_detectors_sdk::test::a::S" +// ) +// .is_some() +// ); +// assert!( +// resolve_path( +// &table, +// scope_id, +// "soroban_security_detectors_sdk::test::a::b::E" +// ) +// .is_some() +// ); +// assert!( +// resolve_path( +// &table, +// scope_id, +// "soroban_security_detectors_sdk::test::a::X" +// ) +// .is_none() +// ); +// assert!(resolve_path(&table, scope_id, "").is_none()); +// } + +// #[test] +// fn test_import_target_resolution() { +// let mut cb = Codebase::::default(); +// let file1 = r" +// pub type MyType = u8; +// pub mod sub { +// pub struct SubType; +// } +// " +// .to_string(); + +// let file2 = r" +// use file1::MyType; +// use file1::sub::SubType as Renamed; +// " +// .to_string(); +// let mut data = HashMap::new(); +// data.insert(format!("test{0}file1.rs", std::path::MAIN_SEPARATOR), file1); +// data.insert(format!("test{0}file2.rs", std::path::MAIN_SEPARATOR), file2); +// let sealed = cb.build_api(&data).unwrap(); +// let table = sealed.symbol_table.clone(); + +// let file2 = sealed.files().find(|f| f.name == "file2.rs").unwrap(); +// let uses = file2 +// .children +// .borrow() +// .iter() +// .filter_map(|node| { +// if let NodeKind::Directive(Directive::Use(u)) = node { +// Some(u.clone()) +// } else { +// None +// } +// }) +// .collect::>(); +// assert_eq!(uses.len(), 2); + +// let full1 = "file1::MyType".to_string(); +// let u1 = uses +// .iter() +// .find(|u| u.imported_types == vec![full1.clone()]) +// .unwrap(); +// let DefinitionRef::Ref(_, my_def) = resolve_path( +// &table, +// table +// .mod_scopes +// .iter() +// .find(|s| !s.1.borrow().definitions.is_empty()) +// .unwrap() +// .1 +// .borrow() +// .id, +// "soroban_security_detectors_sdk::file1::MyType", +// ) +// .unwrap() else { +// panic!("Expected a reference to a definition"); +// }; +// let binding = u1.target.borrow(); +// println!("Use target: {binding:?}, checking key {full1:?}"); +// let found = binding.get(&full1); +// assert_eq!(found.unwrap().as_ref().unwrap().id(), my_def.id()); + +// let full2 = "file1::sub::SubType%Renamed".to_string(); +// let u2 = uses +// .iter() +// .find(|u| u.imported_types == vec![full2.clone()]) +// .unwrap(); +// let DefinitionRef::Ref(_, sub) = resolve_path( +// &table, +// table.mod_scopes.iter().next().unwrap().1.borrow().id, +// "soroban_security_detectors_sdk::file1::sub::SubType", +// ) +// .unwrap() else { +// panic!("Expected a reference to a definition"); +// }; +// assert_eq!(u2.target.borrow().get(&full2), Some(&Some(sub))); +// } + +// #[test] +// fn test_lookup_symbol_origin_basic() { +// let src = r" +// pub const C: u32 = 5; + +// fn f() { +// let x = C; +// let y = x; +// } +// "; +// let (table, _) = build_table_and_uses(src); +// let root_scope = table +// .mod_scopes +// .get("soroban_security_detectors_sdk::test") +// .unwrap() +// .borrow() +// .id; +// let origin_c_root = table.lookup_symbol_origin(root_scope, "C").unwrap(); +// if let crate::node_type::NodeKind::Definition(ref def) = origin_c_root { +// assert_eq!(def.name(), "C"); +// } else { +// panic!("Expected Definition for C"); +// } +// let f_def = table.get_function_by_name(root_scope, "f").unwrap(); +// let f_scope = f_def.id(); +// let origin_c = table.lookup_symbol_origin(f_scope, "C").unwrap(); +// assert_eq!(origin_c.id(), origin_c_root.id()); +// let origin_x = table.lookup_symbol_origin(f_scope, "x").unwrap(); +// if let crate::node_type::NodeKind::Statement(stmt) = origin_x.clone() { +// if let Statement::Let(let_stmt) = stmt { +// assert_eq!(let_stmt.name, "x"); +// } else { +// panic!("Expected Let statement for x"); +// } +// } else { +// panic!("Expected Statement variant for x origin"); +// } +// let origin_y = table.lookup_symbol_origin(f_scope, "y").unwrap(); +// if let crate::node_type::NodeKind::Statement(stmt) = origin_y.clone() { +// if let Statement::Let(let_stmt) = stmt { +// assert_eq!(let_stmt.name, "y"); +// } else { +// panic!("Expected Let statement for y"); +// } +// } else { +// panic!("Expected Statement variant for y origin"); +// } +// let origin_x_again = table.lookup_symbol_origin(f_scope, "x").unwrap(); +// assert_eq!(origin_x_again.id(), origin_x.id()); +// assert!(table.lookup_symbol_origin(f_scope, "z").is_none()); +// } + +// #[test] +// fn test_lookup_symbol_origin_parameter() { +// let src = r" +// fn g(p: u32) { +// let q = p; +// } +// "; +// let (table, _) = build_table_and_uses(src); +// let root_scope = table +// .mod_scopes +// .get("soroban_security_detectors_sdk::test") +// .unwrap() +// .borrow() +// .id; +// let g_def = table.get_function_by_name(root_scope, "g").unwrap(); +// let g_scope = g_def.id(); +// let origin_p = table.lookup_symbol_origin(g_scope, "p").unwrap(); +// if let crate::node_type::NodeKind::Misc(Misc::FnParameter(param)) = origin_p { +// assert_eq!(param.name, "p"); +// } else { +// panic!("Expected FnParameter origin for p"); +// } +// } + +// #[test] +// fn test_lookup_symbol_origin_parameter_single_call_constant() { +// let src = r" +// const C: u32 = 5; + +// fn g(p: u32) { +// let q = p; +// } + +// fn f() { +// g(C); +// } +// "; +// let (table, _) = build_table_and_uses(src); +// let root_scope = table +// .mod_scopes +// .get("soroban_security_detectors_sdk::test") +// .unwrap() +// .borrow() +// .id; +// let g_scope = table.get_function_by_name(root_scope, "g").unwrap().id(); +// let origin_c = table.lookup_symbol_origin(root_scope, "C").unwrap(); +// let origin_p = table.lookup_symbol_origin(g_scope, "p").unwrap(); +// assert_eq!(origin_p.id(), origin_c.id()); +// } + +// #[test] +// fn test_lookup_symbol_origin_parameter_single_call_variable() { +// let src = r" +// fn g(p: u32) { +// let q = p; +// } + +// fn f() { +// let x = 42; +// g(x); +// } +// "; +// let (table, _) = build_table_and_uses(src); +// let root_scope = table +// .mod_scopes +// .get("soroban_security_detectors_sdk::test") +// .unwrap() +// .borrow() +// .id; +// let g_scope = table.get_function_by_name(root_scope, "g").unwrap().id(); +// let f_scope = table.get_function_by_name(root_scope, "f").unwrap().id(); +// let origin_x = table.lookup_symbol_origin(f_scope, "x").unwrap(); +// let origin_p = table.lookup_symbol_origin(g_scope, "p").unwrap(); +// assert_eq!(origin_p.id(), origin_x.id()); +// } + +// #[test] +// fn test_lookup_symbol_origin_parameter_multiple_calls() { +// let src = r" +// fn g(p: u32) { +// let q = p; +// } + +// fn a() { g(1); } +// fn b() { g(2); } +// "; +// let (table, _) = build_table_and_uses(src); +// let root_scope = table +// .mod_scopes +// .get("soroban_security_detectors_sdk::test") +// .unwrap() +// .borrow() +// .id; +// let g_scope = table.get_function_by_name(root_scope, "g").unwrap().id(); +// let origin_p = table.lookup_symbol_origin(g_scope, "p").unwrap(); +// if let crate::node_type::NodeKind::Misc(Misc::FnParameter(param)) = origin_p { +// assert_eq!(param.name, "p"); +// } else { +// panic!("Expected FnParameter for multiple calls"); +// } +// } + +// #[test] +// fn test_lookup_symbol_origin_parameter_single_call_nested_block() { +// let src = r" +// const C: u32 = 5; + +// fn h(p: u32) { +// let q = p; +// } + +// fn f() { +// if true { +// let x = C; +// h(x); +// } +// } +// "; +// let (table, _) = build_table_and_uses(src); +// let root_scope = table +// .mod_scopes +// .get("soroban_security_detectors_sdk::test") +// .unwrap() +// .borrow() +// .id; +// let h_scope = table.get_function_by_name(root_scope, "h").unwrap().id(); +// let f_scope = table.get_function_by_name(root_scope, "f").unwrap().id(); +// let origin_x = table.lookup_symbol_origin(f_scope, "x").unwrap(); +// let origin_p = table.lookup_symbol_origin(h_scope, "p").unwrap(); +// assert_eq!(origin_p.id(), origin_x.id()); +// } + +// #[test] +// fn test_import_across_file_scopes() { +// let mut cb = Codebase::::default(); +// let file1 = r" +// pub struct AStruct; +// " +// .to_string(); + +// let file2 = r" +// use file1::AStruct; +// " +// .to_string(); +// let mut data = HashMap::new(); +// data.insert(format!("test{0}file1.rs", std::path::MAIN_SEPARATOR), file1); +// data.insert(format!("test{0}file2.rs", std::path::MAIN_SEPARATOR), file2); + +// let sealed = cb.build_api(&data).unwrap(); +// let table = sealed.symbol_table.clone(); + +// let file2 = sealed.files().find(|f| f.name == "file2.rs").unwrap(); +// let uses = file2 +// .children +// .borrow() +// .iter() +// .filter_map(|node| { +// if let NodeKind::Directive(Directive::Use(u)) = node { +// Some(u.clone()) +// } else { +// None +// } +// }) +// .collect::>(); +// assert_eq!(uses.len(), 1); + +// let u = &uses[0]; +// let scope = table +// .mod_scopes +// .get("soroban_security_detectors_sdk::file2") +// .expect("Module scope for file2 not found"); +// let DefinitionRef::Ref(_, expected) = resolve_path( +// &table, +// scope.borrow().id, +// "soroban_security_detectors_sdk::file1::AStruct", +// ) +// .unwrap() else { +// panic!("Expected a reference to a definition"); +// }; +// let key = &u.imported_types[0]; +// assert_eq!(u.target.borrow().get(key), Some(&Some(expected))); +// } +// } From 24fa9537464cde6d7b3f422f06915a5c9cb2bdc4 Mon Sep 17 00:00:00 2001 From: Georgii Plotnikov Date: Mon, 15 Dec 2025 20:58:13 +0900 Subject: [PATCH 3/4] Enhance HIR structure by adding SourceFile support and refactoring related components --- core/ast/src/arena.rs | 3 +- core/ast/src/builder.rs | 3 +- core/ast/src/enums_impl.rs | 94 ++++++++++- core/ast/src/lib.rs | 1 - core/ast/src/node_type.rs | 311 ------------------------------------- core/ast/src/nodes.rs | 7 +- core/ast/src/nodes_impl.rs | 77 --------- core/hir/src/arena.rs | 6 +- core/hir/src/nodes.rs | 44 +++++- 9 files changed, 148 insertions(+), 398 deletions(-) delete mode 100644 core/ast/src/node_type.rs diff --git a/core/ast/src/arena.rs b/core/ast/src/arena.rs index 589160d1..357bdcf3 100644 --- a/core/ast/src/arena.rs +++ b/core/ast/src/arena.rs @@ -1,11 +1,12 @@ use std::{collections::HashMap, rc::Rc}; -use crate::nodes::{AstNode, Definition, TypeDefinition}; +use crate::nodes::{AstNode, Definition, SourceFile, TypeDefinition}; #[derive(Default, Clone)] pub struct Arena { pub(crate) nodes: HashMap, pub(crate) node_routes: Vec, + pub sources: Vec, } impl Arena { diff --git a/core/ast/src/builder.rs b/core/ast/src/builder.rs index f05e6069..91c10e78 100644 --- a/core/ast/src/builder.rs +++ b/core/ast/src/builder.rs @@ -104,8 +104,9 @@ impl<'a> Builder<'a, InitState> { } res.push(ast); } + self.arena.sources = res; Ok(Builder { - arena: Arena::default(), + arena: std::mem::take(&mut self.arena), source_code: Vec::new(), _state: PhantomData, }) diff --git a/core/ast/src/enums_impl.rs b/core/ast/src/enums_impl.rs index 49cbc49c..5041f4d6 100644 --- a/core/ast/src/enums_impl.rs +++ b/core/ast/src/enums_impl.rs @@ -1,4 +1,96 @@ -use crate::nodes::Type; +use crate::nodes::{BlockType, Definition, Expression, Statement, Type}; + +impl Definition { + #[must_use] + pub fn name(&self) -> String { + match self { + Definition::Spec(spec) => spec.name.name(), + Definition::Struct(struct_def) => struct_def.name.name(), + Definition::Enum(enum_def) => enum_def.name.name(), + Definition::Constant(const_def) => const_def.name.name(), + Definition::Function(func_def) => func_def.name.name(), + Definition::ExternalFunction(ext_func_def) => ext_func_def.name.name(), + Definition::Type(type_def) => type_def.name.name(), + } + } +} + +impl BlockType { + #[must_use] + pub fn statements(&self) -> Vec { + match self { + BlockType::Block(block) + | BlockType::Forall(block) + | BlockType::Assume(block) + | BlockType::Exists(block) + | BlockType::Unique(block) => block.statements.clone(), + } + } + #[must_use] + pub fn is_non_det(&self) -> bool { + match self { + BlockType::Block(block) => block + .statements + .iter() + .any(super::nodes::Statement::is_non_det), + _ => true, + } + } + #[must_use] + pub fn is_void(&self) -> bool { + let fn_find_ret_stmt = |statements: &Vec| -> bool { + for stmt in statements { + match stmt { + Statement::Return(_) => return true, + Statement::Block(block_type) => { + if block_type.is_void() { + return true; + } + } + _ => {} + } + } + false + }; + !fn_find_ret_stmt(&self.statements()) + } +} + +impl Statement { + #[must_use] + pub fn is_non_det(&self) -> bool { + match self { + Statement::Block(block_type) => !matches!(block_type, BlockType::Block(_)), + Statement::Expression(expr_stmt) => expr_stmt.is_non_det(), + Statement::Return(ret_stmt) => ret_stmt.expression.borrow().is_non_det(), + Statement::Loop(loop_stmt) => loop_stmt + .condition + .borrow() + .as_ref() + .is_some_and(super::nodes::Expression::is_non_det), + Statement::If(if_stmt) => { + if_stmt.condition.borrow().is_non_det() + || if_stmt.if_arm.is_non_det() + || if_stmt + .else_arm + .as_ref() + .is_some_and(super::nodes::BlockType::is_non_det) + } + Statement::VariableDefinition(var_def) => var_def + .value + .as_ref() + .is_some_and(|value| value.borrow().is_non_det()), + _ => false, + } + } +} + +impl Expression { + #[must_use] + pub fn is_non_det(&self) -> bool { + matches!(self, Expression::Uzumaki(_)) + } +} impl Type { pub(crate) fn is_unit_type(&self) -> bool { diff --git a/core/ast/src/lib.rs b/core/ast/src/lib.rs index 8df1afac..fa3abec1 100644 --- a/core/ast/src/lib.rs +++ b/core/ast/src/lib.rs @@ -2,6 +2,5 @@ pub mod arena; pub mod builder; pub(crate) mod enums_impl; -pub mod node_type; pub mod nodes; pub(crate) mod nodes_impl; diff --git a/core/ast/src/node_type.rs b/core/ast/src/node_type.rs deleted file mode 100644 index a762bc31..00000000 --- a/core/ast/src/node_type.rs +++ /dev/null @@ -1,311 +0,0 @@ -//! AST node types for Rust type annotations. -//! -//! Defines the `NodeType` enum representing parsed type expressions such as paths, -//! references, pointers, tuples, arrays, and more. - -use std::rc::Rc; - -use crate::nodes::{ - Definition, Directive, EnumDefinition, Expression, FunctionCallExpression, FunctionDefinition, - Literal, Location, SourceFile, Statement, StructDefinition, Type, -}; -pub type RcFile = Rc; -pub type RcContract = Rc; -pub type RcFunction = Rc; -pub type RcExpression = Rc; -pub type RcFunctionCall = Rc; - -pub type RcEnum = Rc; -pub type RcStruct = Rc; - -#[derive(Clone, PartialEq, Eq, Debug, Default, serde::Serialize, serde::Deserialize)] -pub enum NodeType { - #[default] - Empty, - /// A named type or path, including any generics as represented in the token stream - Path(String), - /// A reference `&T` or `&mut T`, with explicit flag - Reference { - inner: Box, - mutable: bool, - is_explicit_reference: bool, - }, - /// A raw pointer `*const T` or `*mut T` - Ptr { inner: Box, mutable: bool }, - /// A tuple type `(T1, T2, ...)` - Tuple(Vec), - /// An array type `[T; len]`, with optional length if parseable - Array { - inner: Box, - len: Option, - }, - /// A slice type `[T]` - Slice(Box), - /// A bare function pointer `fn(a, b) -> R` - BareFn { - inputs: Vec, - output: Box, - }, - /// A generic type annotation, e.g., `Option`, `Result` - Generic { - base: Box, - args: Vec, - }, - /// A trait object type `dyn Trait1 + Trait2` - TraitObject(Vec), - /// An `impl Trait` type - ImplTrait(Vec), - Closure { - inputs: Vec, - output: Box, - }, -} - -impl NodeType { - #[must_use] - pub fn name(&self) -> String { - match self { - NodeType::Path(name) => name.clone(), - NodeType::Reference { - inner, - is_explicit_reference, - .. - } => { - if *is_explicit_reference { - format!("&{}", inner.name()) - } else { - inner.name() - } - } - NodeType::Ptr { inner, mutable } => { - let star = if *mutable { "*mut" } else { "*const" }; - format!("{} {}", star, inner.name()) - } - NodeType::Tuple(elems) => format!( - "({})", - elems - .iter() - .map(NodeType::name) - .collect::>() - .join(", ") - ), - NodeType::Array { inner, len } => format!( - "[{}; {}]", - inner.name(), - len.map_or("..".to_string(), |l| l.to_string()) - ), - NodeType::Slice(inner) => format!("[{}]", inner.name()), - NodeType::BareFn { inputs, output } => { - let mut result = inputs - .iter() - .map(NodeType::name) - .collect::>() - .join(", "); - if result.is_empty() { - result = "_".to_string(); - } - let output = if output.name().is_empty() { - "_".to_string() - } else { - output.name() - }; - format!("fn({result}) -> {output}") - } - NodeType::Closure { inputs, output } => { - let mut result = inputs - .iter() - .map(NodeType::name) - .collect::>() - .join(", "); - if result.is_empty() { - result = "_".to_string(); - } - let output = if output.name().is_empty() { - "_".to_string() - } else { - output.name() - }; - format!("{result} || -> {output}") - } - NodeType::Generic { base, args } => format!( - "{}<{}>", - base.name(), - args.iter() - .map(NodeType::name) - .collect::>() - .join(", ") - ), - NodeType::TraitObject(bounds) => format!("dyn {}", bounds.join(" + ")), - NodeType::ImplTrait(bounds) => format!("impl {}", bounds.join(" + ")), - NodeType::Empty => String::from("_"), - } - } - - #[must_use] - pub fn pure_name(&self) -> String { - match self { - NodeType::Path(name) => name.clone(), - NodeType::Reference { inner, .. } - | NodeType::Ptr { inner, .. } - | NodeType::Array { inner, len: _ } - | NodeType::Slice(inner) => inner.pure_name(), - NodeType::Tuple(elems) => format!( - "({})", - elems - .iter() - .map(NodeType::pure_name) - .collect::>() - .join(", ") - ), - NodeType::BareFn { inputs, output } => { - let mut result = inputs - .iter() - .map(NodeType::pure_name) - .collect::>() - .join(", "); - if result.is_empty() { - result = "_".to_string(); - } - let output = output.pure_name(); - format!("fn({result}) -> {output}") - } - NodeType::Closure { inputs, output } => { - let mut result = inputs - .iter() - .map(NodeType::pure_name) - .collect::>() - .join(", "); - if result.is_empty() { - result = "_".to_string(); - } - let output = output.pure_name(); - format!("{result} || -> {output}") - } - NodeType::Generic { base, args } => format!( - "{}<{}>", - base.pure_name(), - args.iter() - .map(NodeType::pure_name) - .collect::>() - .join(", ") - ), - NodeType::TraitObject(bounds) | NodeType::ImplTrait(bounds) => bounds.join(" + "), - NodeType::Empty => String::from("_"), - } - } - - #[must_use] - pub fn is_self(&self) -> bool { - match self { - NodeType::Path(name) => name.to_lowercase() == "self", - NodeType::Reference { inner, .. } - | NodeType::Ptr { inner, .. } - | NodeType::Array { inner, .. } - | NodeType::Slice(inner) => inner.is_self(), - NodeType::Tuple(elems) => elems.iter().any(NodeType::is_self), - NodeType::BareFn { inputs, output } | NodeType::Closure { inputs, output } => { - inputs.iter().any(NodeType::is_self) || output.is_self() - } - NodeType::Generic { base, args } => { - base.is_self() || args.iter().any(NodeType::is_self) - } - NodeType::TraitObject(bounds) | NodeType::ImplTrait(bounds) => { - bounds.iter().any(|b| b.to_lowercase() == "self") - } - NodeType::Empty => false, - } - } - - #[allow(clippy::assigning_clones)] - pub fn replace_path(&mut self, new_path: String) { - match self { - NodeType::Path(_) => { - *self = NodeType::Path(new_path); - } - NodeType::Reference { inner, .. } - | NodeType::Ptr { inner, .. } - | NodeType::Array { inner, .. } - | NodeType::Slice(inner) => { - inner.replace_path(new_path); - } - NodeType::Tuple(elems) => { - for elem in elems { - elem.replace_path(new_path.clone()); - } - } - NodeType::BareFn { inputs, output } | NodeType::Closure { inputs, output } => { - for input in inputs { - input.replace_path(new_path.clone()); - } - output.replace_path(new_path); - } - NodeType::Generic { base, args } => { - base.replace_path(new_path.clone()); - for arg in args { - arg.replace_path(new_path.clone()); - } - } - NodeType::TraitObject(bounds) | NodeType::ImplTrait(bounds) => { - for bound in bounds.iter_mut() { - if bound.to_lowercase() == "self" { - *bound = new_path.clone(); - } - } - } - NodeType::Empty => {} - } - } -} - -#[derive(Clone, Debug)] -pub enum NodeKind { - File(Rc), - Directive(Directive), - Definition(Definition), - Statement(Statement), - Expression(Expression), - Literal(Literal), - Type(Type), -} - -impl NodeKind { - #[must_use] - pub fn id(&self) -> u32 { - match self { - NodeKind::File(f) => f.id, - NodeKind::Definition(d) => d.id(), - NodeKind::Directive(d) => d.id(), - NodeKind::Statement(s) => s.id(), - NodeKind::Expression(e) => e.id(), - NodeKind::Literal(l) => l.id(), - NodeKind::Type(t) => t.id(), - } - } - - #[must_use] - #[allow(clippy::cast_possible_truncation)] - pub fn location(&self) -> Location { - match self { - NodeKind::File(f) => f.location().clone(), - NodeKind::Definition(d) => d.location().clone(), - NodeKind::Directive(d) => d.location(), - NodeKind::Statement(s) => s.location(), - NodeKind::Expression(e) => e.location(), - NodeKind::Literal(l) => l.location(), - NodeKind::Type(t) => t.location(), - } - } - - #[must_use] - pub fn children(&self) -> Vec { - match self { - NodeKind::File(file) => file.children(), - NodeKind::Definition(definition) => definition.children(), - NodeKind::Directive(directive) => directive.children(), - NodeKind::Statement(statement) => statement.children(), - NodeKind::Expression(expression) => expression.children(), - NodeKind::Literal(literal) => literal.children(), - NodeKind::Type(ty) => ty.children(), - } - } -} diff --git a/core/ast/src/nodes.rs b/core/ast/src/nodes.rs index 1422db17..5af0247e 100644 --- a/core/ast/src/nodes.rs +++ b/core/ast/src/nodes.rs @@ -95,6 +95,7 @@ macro_rules! ast_nodes { }; } +#[macro_export] macro_rules! ast_enum { ( $(#[$outer:meta])* @@ -308,7 +309,6 @@ pub enum OperatorKind { } ast_nodes! { - pub struct SourceFile { pub directives: Vec, pub definitions: Vec, @@ -518,7 +518,6 @@ ast_nodes! { pub struct TypeArray { pub element_type: Type, - pub size: Option + pub size: Option, } - -} +} \ No newline at end of file diff --git a/core/ast/src/nodes_impl.rs b/core/ast/src/nodes_impl.rs index 2247f8f9..5fd3254f 100644 --- a/core/ast/src/nodes_impl.rs +++ b/core/ast/src/nodes_impl.rs @@ -106,83 +106,6 @@ impl SourceFile { } } -impl BlockType { - #[must_use] - pub fn statements(&self) -> Vec { - match self { - BlockType::Block(block) - | BlockType::Forall(block) - | BlockType::Assume(block) - | BlockType::Exists(block) - | BlockType::Unique(block) => block.statements.clone(), - } - } - #[must_use] - pub fn is_non_det(&self) -> bool { - match self { - BlockType::Block(block) => block - .statements - .iter() - .any(super::nodes::Statement::is_non_det), - _ => true, - } - } - #[must_use] - pub fn is_void(&self) -> bool { - let fn_find_ret_stmt = |statements: &Vec| -> bool { - for stmt in statements { - match stmt { - Statement::Return(_) => return true, - Statement::Block(block_type) => { - if block_type.is_void() { - return true; - } - } - _ => {} - } - } - false - }; - !fn_find_ret_stmt(&self.statements()) - } -} - -impl Statement { - #[must_use] - pub fn is_non_det(&self) -> bool { - match self { - Statement::Block(block_type) => !matches!(block_type, BlockType::Block(_)), - Statement::Expression(expr_stmt) => expr_stmt.is_non_det(), - Statement::Return(ret_stmt) => ret_stmt.expression.borrow().is_non_det(), - Statement::Loop(loop_stmt) => loop_stmt - .condition - .borrow() - .as_ref() - .is_some_and(super::nodes::Expression::is_non_det), - Statement::If(if_stmt) => { - if_stmt.condition.borrow().is_non_det() - || if_stmt.if_arm.is_non_det() - || if_stmt - .else_arm - .as_ref() - .is_some_and(super::nodes::BlockType::is_non_det) - } - Statement::VariableDefinition(var_def) => var_def - .value - .as_ref() - .is_some_and(|value| value.borrow().is_non_det()), - _ => false, - } - } -} - -impl Expression { - #[must_use] - pub fn is_non_det(&self) -> bool { - matches!(self, Expression::Uzumaki(_)) - } -} - impl UseDirective { #[must_use] pub fn new( diff --git a/core/hir/src/arena.rs b/core/hir/src/arena.rs index 3072e9fe..dd1e4ed7 100644 --- a/core/hir/src/arena.rs +++ b/core/hir/src/arena.rs @@ -1,2 +1,6 @@ +use crate::nodes::SourceFile; + #[derive(Default, Clone)] -pub struct Arena {} +pub struct Arena { + pub sources: Vec, +} \ No newline at end of file diff --git a/core/hir/src/nodes.rs b/core/hir/src/nodes.rs index 29adac5e..85fa01a4 100644 --- a/core/hir/src/nodes.rs +++ b/core/hir/src/nodes.rs @@ -2,6 +2,7 @@ use std::rc::Rc; use crate::type_info::TypeInfo; +#[derive(Clone, Debug)] pub enum Directive { Use(Rc), } @@ -17,6 +18,7 @@ pub enum Definition { Type(Rc), } +#[derive(Clone, Debug)] pub enum BlockType { Block(Rc), Assume(Rc), @@ -25,6 +27,7 @@ pub enum BlockType { Unique(Rc), } +#[derive(Clone, Debug)] pub enum Statement { Block(BlockType), Expression(Expression), @@ -39,6 +42,7 @@ pub enum Statement { ConstantDefinition(Rc), } +#[derive(Clone, Debug)] pub enum Expression { ArrayIndexAccess(Rc), Binary(Rc), @@ -49,10 +53,11 @@ pub enum Expression { PrefixUnary(Rc), Parenthesized(Rc), Literal(Literal), - TypeInfo(TypeInfo), //TODO: need it + TypeInfo(TypeInfo), Uzumaki(Rc), } +#[derive(Clone, Debug)] pub enum Literal { Array(Rc), Bool(Rc), @@ -61,6 +66,7 @@ pub enum Literal { Unit(Rc), } +#[derive(Clone, Debug)] pub enum ArgumentType { SelfReference(Rc), IgnoreArgument(Rc), @@ -101,6 +107,7 @@ pub enum OperatorKind { Shr, } +#[derive(Clone, Debug)] pub struct SourceFile { pub directives: Vec, pub definitions: Vec, @@ -113,33 +120,39 @@ pub struct UseDirective { pub from: Option, } +#[derive(Clone, Debug)] pub struct SpecDefinition { pub name: String, pub definitions: Vec, } +#[derive(Clone, Debug)] pub struct StructDefinition { pub name: String, pub fields: Vec>, pub methods: Vec>, } +#[derive(Clone, Debug)] pub struct StructField { pub name: String, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct EnumDefinition { pub name: String, pub variants: Vec, } +#[derive(Clone, Debug)] pub struct ConstantDefinition { pub name: String, pub type_info: TypeInfo, pub value: Literal, } +#[derive(Clone, Debug)] pub struct FunctionDefinition { pub name: String, pub type_parameters: Option>, @@ -148,56 +161,68 @@ pub struct FunctionDefinition { pub body: BlockType, } +#[derive(Clone, Debug)] pub struct ExternalFunctionDefinition { pub name: String, pub arguments: Option>, pub returns: Option, } +#[derive(Clone, Debug)] pub struct TypeDefinition { pub name: String, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct Argument { pub name: String, pub is_mut: bool, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct SelfReference { pub is_mut: bool, } +#[derive(Clone, Debug)] pub struct IgnoreArgument { pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct Block { pub statements: Vec, } +#[derive(Clone, Debug)] pub struct ExpressionStatement { pub expression: Expression, } +#[derive(Clone, Debug)] pub struct ReturnStatement { pub expression: Expression, } +#[derive(Clone, Debug)] pub struct LoopStatement { pub condition: Option, pub body: BlockType, } +#[derive(Clone, Debug)] pub struct BreakStatement {} +#[derive(Clone, Debug)] pub struct IfStatement { pub condition: Expression, pub if_arm: BlockType, pub else_arm: Option, } +#[derive(Clone, Debug)] pub struct VariableDefinitionStatement { pub name: String, pub type_info: TypeInfo, @@ -205,34 +230,40 @@ pub struct VariableDefinitionStatement { pub is_uzumaki: bool, } +#[derive(Clone, Debug)] pub struct TypeDefinitionStatement { pub name: String, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct AssignStatement { pub left: Expression, pub right: Expression, } +#[derive(Clone, Debug)] pub struct ArrayIndexAccessExpression { pub array: Expression, pub index: Expression, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct MemberAccessExpression { pub expression: Expression, pub name: String, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct TypeMemberAccessExpression { pub expression: Expression, pub name: String, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct FunctionCallExpression { pub name: String, pub function: Expression, @@ -241,31 +272,37 @@ pub struct FunctionCallExpression { pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct StructExpression { pub name: String, pub fields: Option>, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct UzumakiExpression { pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct PrefixUnaryExpression { pub expression: Expression, pub operator: UnaryOperatorKind, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct AssertStatement { pub expression: Expression, } +#[derive(Clone, Debug)] pub struct ParenthesizedExpression { pub expression: Expression, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct BinaryExpression { pub left: Expression, pub operator: OperatorKind, @@ -273,24 +310,29 @@ pub struct BinaryExpression { pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct ArrayLiteral { pub elements: Option>, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct BoolLiteral { pub value: bool, } +#[derive(Clone, Debug)] pub struct StringLiteral { pub value: String, } +#[derive(Clone, Debug)] pub struct NumberLiteral { pub value: String, pub type_info: TypeInfo, } +#[derive(Clone, Debug)] pub struct UnitLiteral {} pub struct SimpleType { From 6ab1a4732384fd2912a5eaf9983275cbdf3ecf9e Mon Sep 17 00:00:00 2001 From: Georgii Plotnikov Date: Tue, 16 Dec 2025 15:21:22 +0900 Subject: [PATCH 4/4] Refactor HIR module: enhance node handling and add function definition extraction --- core/ast/src/arena.rs | 8 + core/hir/src/hir.rs | 491 +++++++++++++++++++++++++++++- core/hir/src/lib.rs | 2 +- core/hir/src/nodes.rs | 2 + core/hir/src/nodes_impl.rs | 28 +- core/wasm-codegen/src/compiler.rs | 11 +- core/wasm-codegen/src/lib.rs | 18 +- 7 files changed, 533 insertions(+), 27 deletions(-) diff --git a/core/ast/src/arena.rs b/core/ast/src/arena.rs index 357bdcf3..586a7a8b 100644 --- a/core/ast/src/arena.rs +++ b/core/ast/src/arena.rs @@ -10,6 +10,11 @@ pub struct Arena { } impl Arena { + /// Adds a node to the arena with the specified parent. + /// + /// # Panics + /// + /// Panics if the node ID is zero or if a node with the same ID already exists in the arena. pub fn add_node(&mut self, node: AstNode, parent_id: u32) { // println!("Adding node with ID: {node:?}"); assert!(node.id() != 0, "Node ID must be non-zero"); @@ -37,10 +42,12 @@ impl Arena { self.node_routes.push(node); } + #[must_use] pub fn find_node(&self, id: u32) -> Option { self.nodes.get(&id).cloned() } + #[must_use] pub fn find_parent_node(&self, id: u32) -> Option { self.node_routes .iter() @@ -74,6 +81,7 @@ impl Arena { result } + #[must_use] pub fn list_type_definitions(&self) -> Vec> { self.list_nodes_cmp(|node| { if let AstNode::Definition(Definition::Type(type_def)) = node { diff --git a/core/hir/src/hir.rs b/core/hir/src/hir.rs index 2d6327c3..f76e0b51 100644 --- a/core/hir/src/hir.rs +++ b/core/hir/src/hir.rs @@ -1,5 +1,40 @@ -use crate::{arena::Arena, symbol_table::SymbolTable}; -use inference_ast::arena::Arena as AstArena; +use crate::{ + arena::Arena, + nodes::{ + Argument as HirArgument, ArgumentType as HirArgumentType, + ArrayIndexAccessExpression as HirArrayIndexAccessExpression, + ArrayLiteral as HirArrayLiteral, AssertStatement as HirAssertStatement, + AssignStatement as HirAssignStatement, BinaryExpression as HirBinaryExpression, + Block as HirBlock, BlockType as HirBlockType, BoolLiteral as HirBoolLiteral, + BreakStatement as HirBreakStatement, ConstantDefinition as HirConstantDefinition, + Definition as HirDefinition, EnumDefinition as HirEnumDefinition, + Expression as HirExpression, ExternalFunctionDefinition as HirExternalFunctionDefinition, + FunctionCallExpression as HirFunctionCallExpression, + FunctionDefinition as HirFunctionDefinition, IfStatement as HirIfStatement, + IgnoreArgument as HirIgnoreArgument, Literal as HirLiteral, + LoopStatement as HirLoopStatement, MemberAccessExpression as HirMemberAccessExpression, + NumberLiteral as HirNumberLiteral, OperatorKind as HirOperatorKind, + ParenthesizedExpression as HirParenthesizedExpression, + PrefixUnaryExpression as HirPrefixUnaryExpression, ReturnStatement as HirReturnStatement, + SelfReference as HirSelfReference, SourceFile as HirSourceFile, + SpecDefinition as HirSpecDefinition, Statement as HirStatement, + StringLiteral as HirStringLiteral, StructDefinition as HirStructDefinition, + StructExpression as HirStructExpression, StructField as HirStructField, + TypeDefinition as HirTypeDefinition, TypeDefinitionStatement as HirTypeDefinitionStatement, + TypeMemberAccessExpression as HirTypeMemberAccessExpression, + UnaryOperatorKind as HirUnaryOperator, UnitLiteral as HirUnitLiteral, + UzumakiExpression as HirUzumakiExpression, + VariableDefinitionStatement as HirVariableDefinitionStatement, + }, + symbol_table::{ScopeRef, SymbolTable}, + type_info::{NumberTypeKindNumberType, TypeInfo, TypeInfoKind}, +}; +use inference_ast::{ + arena::Arena as AstArena, + nodes as ast, + nodes::{Definition as AstDefinition, SourceFile as AstSourceFile}, +}; +use std::rc::Rc; #[derive(Clone, Default)] pub struct Hir { @@ -9,10 +44,456 @@ pub struct Hir { impl Hir { #[must_use] - pub fn new(arena: AstArena) -> Self { + pub fn new(ast_arena: &AstArena) -> Self { + let mut symbol_table = SymbolTable::new(); + // Create root scope + let root_scope = crate::symbol_table::Scope::new(0, "crate".to_string(), None); + symbol_table.insert_scope(root_scope.clone()); + + // Pass 1: Build Symbol Table + // ast_arena.sources is Vec. + // Process all definitions in all source files. + for file in &ast_arena.sources { + for def in &file.definitions { + crate::symbol_table::process_definition( + &root_scope, + def.clone(), + &mut symbol_table, + ); + } + } + + symbol_table.build_symbol_tables(); + + // Pass 2: Build HIR + let mut hir_sources = Vec::new(); + for file in &ast_arena.sources { + hir_sources.push(transform_source_file(file, &symbol_table, &root_scope)); + } + Self { - arena: Arena::default(), - symbol_table: SymbolTable::default(), + arena: Arena { + sources: hir_sources, + }, + symbol_table, + } + } +} + +fn transform_source_file( + file: &AstSourceFile, + table: &SymbolTable, + scope: &ScopeRef, +) -> HirSourceFile { + let definitions = file + .definitions + .iter() + .map(|d| transform_definition(d, table, scope)) + .collect(); + + HirSourceFile { + directives: vec![], + definitions, + } +} + +#[allow(clippy::too_many_lines)] +fn transform_definition( + def: &AstDefinition, + table: &SymbolTable, + scope: &ScopeRef, +) -> HirDefinition { + match def { + AstDefinition::Function(f) => { + let fn_scope = table + .scopes + .get(&f.id) + .cloned() + .unwrap_or_else(|| scope.clone()); + + let arguments = f.arguments.as_ref().map(|args| { + args.iter() + .map(|arg| match arg { + ast::ArgumentType::Argument(a) => { + HirArgumentType::Argument(Rc::new(HirArgument { + name: a.name.name.clone(), + is_mut: a.is_mut, + type_info: TypeInfo::new(&a.ty), + })) + } + ast::ArgumentType::SelfReference(s) => { + HirArgumentType::SelfReference(Rc::new(HirSelfReference { + is_mut: s.is_mut, + })) + } + ast::ArgumentType::IgnoreArgument(i) => { + HirArgumentType::IgnoreArgument(Rc::new(HirIgnoreArgument { + type_info: TypeInfo::new(&i.ty), + })) + } + ast::ArgumentType::Type(t) => HirArgumentType::TypeInfo(TypeInfo::new(t)), + }) + .collect() + }); + + let body = match &f.body { + ast::BlockType::Block(b) => { + HirBlockType::Block(Rc::new(transform_block(b, table, &fn_scope))) + } + ast::BlockType::Assume(b) => { + HirBlockType::Assume(Rc::new(transform_block(b, table, &fn_scope))) + } + ast::BlockType::Forall(b) => { + HirBlockType::Forall(Rc::new(transform_block(b, table, &fn_scope))) + } + ast::BlockType::Exists(b) => { + HirBlockType::Exists(Rc::new(transform_block(b, table, &fn_scope))) + } + ast::BlockType::Unique(b) => { + HirBlockType::Unique(Rc::new(transform_block(b, table, &fn_scope))) + } + }; + + HirDefinition::Function(Rc::new(HirFunctionDefinition { + name: f.name.name.clone(), + type_parameters: f + .type_parameters + .as_ref() + .map(|tp| tp.iter().map(|t| t.name.clone()).collect()), + arguments, + returns: f.returns.as_ref().map(TypeInfo::new), + body, + })) + } + ast::Definition::Struct(s) => HirDefinition::Struct(Rc::new(HirStructDefinition { + name: s.name.name.clone(), + fields: s + .fields + .iter() + .map(|f| { + Rc::new(HirStructField { + name: f.name.name.clone(), + type_info: TypeInfo::new(&f.type_), + }) + }) + .collect(), + methods: s + .methods + .iter() + .map(|m| { + if let HirDefinition::Function(f) = + transform_definition(&ast::Definition::Function(m.clone()), table, scope) + { + f + } else { + panic!("Expected Function definition"); + } + }) + .collect(), + })), + ast::Definition::Enum(e) => HirDefinition::Enum(Rc::new(HirEnumDefinition { + name: e.name.name.clone(), + variants: e.variants.iter().map(|v| v.name.clone()).collect(), + })), + ast::Definition::Constant(c) => HirDefinition::Constant(Rc::new(HirConstantDefinition { + name: c.name.name.clone(), + type_info: TypeInfo::new(&c.ty), + value: transform_literal(&c.value, table, scope), + })), + ast::Definition::Type(t) => HirDefinition::Type(Rc::new(HirTypeDefinition { + name: t.name.name.clone(), + type_info: TypeInfo::new(&t.ty), + })), + ast::Definition::Spec(s) => { + let spec_scope = table + .scopes + .get(&s.name.id) + .cloned() + .unwrap_or(scope.clone()); + HirDefinition::Spec(Rc::new(HirSpecDefinition { + name: s.name.name.clone(), + definitions: s + .definitions + .iter() + .map(|d| transform_definition(d, table, &spec_scope)) + .collect(), + })) + } + ast::Definition::ExternalFunction(f) => { + HirDefinition::ExternalFunction(Rc::new(HirExternalFunctionDefinition { + name: f.name.name.clone(), + arguments: f.arguments.as_ref().map(|args| { + args.iter() + .map(|_| HirArgumentType::TypeInfo(TypeInfo::default())) + .collect() + }), + returns: f.returns.as_ref().map(TypeInfo::new), + })) + } + } +} + +fn transform_block(block: &ast::Block, table: &SymbolTable, scope: &ScopeRef) -> HirBlock { + HirBlock { + statements: block + .statements + .iter() + .map(|s| transform_statement(s, table, scope)) + .collect(), + } +} + +fn transform_statement( + stmt: &ast::Statement, + table: &SymbolTable, + scope: &ScopeRef, +) -> HirStatement { + match stmt { + ast::Statement::VariableDefinition(v) => { + HirStatement::VariableDefinition(Rc::new(HirVariableDefinitionStatement { + name: v.name.name.clone(), + type_info: TypeInfo::new(&v.ty), + value: v + .value + .as_ref() + .map(|e| transform_expression(&e.borrow(), table, scope)), + is_uzumaki: v.is_uzumaki, + })) + } + ast::Statement::Expression(e) => { + HirStatement::Expression(transform_expression(e, table, scope)) + } + ast::Statement::Return(r) => HirStatement::Return(Rc::new(HirReturnStatement { + expression: transform_expression(&r.expression.borrow(), table, scope), + })), + ast::Statement::Assign(a) => HirStatement::Assign(Rc::new(HirAssignStatement { + left: transform_expression(&a.left.borrow(), table, scope), + right: transform_expression(&a.right.borrow(), table, scope), + })), + ast::Statement::If(i) => HirStatement::If(Rc::new(HirIfStatement { + condition: transform_expression(&i.condition.borrow(), table, scope), + if_arm: transform_block_type(&i.if_arm, table, scope), + else_arm: i + .else_arm + .as_ref() + .map(|b| transform_block_type(b, table, scope)), + })), + ast::Statement::Loop(l) => HirStatement::Loop(Rc::new(HirLoopStatement { + condition: l + .condition + .borrow() + .as_ref() + .map(|c| transform_expression(c, table, scope)), + body: transform_block_type(&l.body, table, scope), + })), + ast::Statement::Break(_) => HirStatement::Break(Rc::new(HirBreakStatement {})), + ast::Statement::Block(b) => HirStatement::Block(transform_block_type(b, table, scope)), + ast::Statement::Assert(a) => HirStatement::Assert(Rc::new(HirAssertStatement { + expression: transform_expression(&a.expression.borrow(), table, scope), + })), + ast::Statement::ConstantDefinition(c) => { + HirStatement::ConstantDefinition(Rc::new(HirConstantDefinition { + name: c.name.name.clone(), + type_info: TypeInfo::new(&c.ty), + value: transform_literal(&c.value, table, scope), + })) } + ast::Statement::TypeDefinition(t) => { + HirStatement::TypeDefinition(Rc::new(HirTypeDefinitionStatement { + name: t.name.name.clone(), + type_info: TypeInfo::new(&t.ty), + })) + } + } +} + +fn transform_block_type( + bt: &ast::BlockType, + table: &SymbolTable, + scope: &ScopeRef, +) -> HirBlockType { + match bt { + ast::BlockType::Block(b) => HirBlockType::Block(Rc::new(transform_block(b, table, scope))), + ast::BlockType::Assume(b) => { + HirBlockType::Assume(Rc::new(transform_block(b, table, scope))) + } + ast::BlockType::Forall(b) => { + HirBlockType::Forall(Rc::new(transform_block(b, table, scope))) + } + ast::BlockType::Exists(b) => { + HirBlockType::Exists(Rc::new(transform_block(b, table, scope))) + } + ast::BlockType::Unique(b) => { + HirBlockType::Unique(Rc::new(transform_block(b, table, scope))) + } + } +} + +fn transform_expression( + expr: &ast::Expression, + table: &SymbolTable, + scope: &ScopeRef, +) -> HirExpression { + let type_info = table.infer_expr_type(scope.borrow().id, expr); + + match expr { + ast::Expression::Binary(b) => { + // b is Rc + HirExpression::Binary(Rc::new(HirBinaryExpression { + left: transform_expression(&b.left.borrow(), table, scope), + operator: map_operator(&b.operator), + right: transform_expression(&b.right.borrow(), table, scope), + type_info, + })) + } + ast::Expression::Literal(l) => HirExpression::Literal(transform_literal(l, table, scope)), + ast::Expression::FunctionCall(fc) => { + let name = if let ast::Expression::Identifier(id) = &fc.function { + id.name.clone() + } else { + String::new() + }; + + HirExpression::FunctionCall(Rc::new(HirFunctionCallExpression { + name, + function: transform_expression(&fc.function, table, scope), + type_parameters: fc + .type_parameters + .as_ref() + .map(|tp| tp.iter().map(|t| t.name.clone()).collect()), + arguments: fc.arguments.as_ref().map(|args| { + args.iter() + .map(|(n, e)| { + ( + n.as_ref().map(|x| x.name.clone()), + transform_expression(&e.borrow(), table, scope), + ) + }) + .collect() + }), + type_info, + })) + } + ast::Expression::Struct(s) => HirExpression::Struct(Rc::new(HirStructExpression { + name: s.name.name.clone(), + fields: s.fields.as_ref().map(|f| { + f.iter() + .map(|(n, e)| { + ( + n.name.clone(), + transform_expression(&e.borrow(), table, scope), + ) + }) + .collect() + }), + type_info, + })), + ast::Expression::MemberAccess(ma) => { + HirExpression::MemberAccess(Rc::new(HirMemberAccessExpression { + expression: transform_expression(&ma.expression.borrow(), table, scope), + name: ma.name.name.clone(), + type_info, + })) + } + ast::Expression::ArrayIndexAccess(a) => { + HirExpression::ArrayIndexAccess(Rc::new(HirArrayIndexAccessExpression { + array: transform_expression(&a.array.borrow(), table, scope), + index: transform_expression(&a.index.borrow(), table, scope), + type_info, + })) + } + ast::Expression::TypeMemberAccess(t) => { + HirExpression::TypeMemberAccess(Rc::new(HirTypeMemberAccessExpression { + expression: transform_expression(&t.expression.borrow(), table, scope), + name: t.name.name.clone(), + type_info, + })) + } + ast::Expression::Parenthesized(p) => { + HirExpression::Parenthesized(Rc::new(HirParenthesizedExpression { + expression: transform_expression(&p.expression.borrow(), table, scope), + type_info, + })) + } + ast::Expression::PrefixUnary(u) => { + HirExpression::PrefixUnary(Rc::new(HirPrefixUnaryExpression { + expression: transform_expression(&u.expression.borrow(), table, scope), + operator: match u.operator { + ast::UnaryOperatorKind::Neg => HirUnaryOperator::Neg, + }, + type_info, + })) + } + ast::Expression::Uzumaki(_) => { + HirExpression::Uzumaki(Rc::new(HirUzumakiExpression { type_info })) + } + _ => HirExpression::TypeInfo(type_info), + } +} + +fn transform_literal(lit: &ast::Literal, table: &SymbolTable, scope: &ScopeRef) -> HirLiteral { + match lit { + ast::Literal::Bool(b) => HirLiteral::Bool(Rc::new(HirBoolLiteral { value: b.value })), + ast::Literal::String(s) => HirLiteral::String(Rc::new(HirStringLiteral { + value: s.value.clone(), + })), + ast::Literal::Number(n) => HirLiteral::Number(Rc::new(HirNumberLiteral { + value: n.value.clone(), + type_info: TypeInfo { + kind: TypeInfoKind::Number(NumberTypeKindNumberType::I32), + type_params: vec![], + }, + })), + ast::Literal::Unit(_) => HirLiteral::Unit(Rc::new(HirUnitLiteral {})), + ast::Literal::Array(a) => { + let elements = a.elements.as_ref().map(|el| { + el.iter() + .map(|e| transform_expression(&e.borrow(), table, scope)) + .collect::>() + }); + let type_info = if let Some(elems) = &elements { + if let Some(first) = elems.first() { + let inner_type = first.type_info(); + let array_len = u32::try_from(elems.len()).ok(); + TypeInfo { + kind: TypeInfoKind::Array(Box::new(inner_type), array_len), + type_params: vec![], + } + } else { + TypeInfo::default() + } + } else { + TypeInfo::default() + }; + HirLiteral::Array(Rc::new(HirArrayLiteral { + elements, + type_info, + })) + } + } +} + +fn map_operator(op: &ast::OperatorKind) -> HirOperatorKind { + match op { + ast::OperatorKind::Pow => HirOperatorKind::Pow, + ast::OperatorKind::Add => HirOperatorKind::Add, + ast::OperatorKind::Sub => HirOperatorKind::Sub, + ast::OperatorKind::Mul => HirOperatorKind::Mul, + ast::OperatorKind::Div => HirOperatorKind::Div, + ast::OperatorKind::Mod => HirOperatorKind::Mod, + ast::OperatorKind::And => HirOperatorKind::And, + ast::OperatorKind::Or => HirOperatorKind::Or, + ast::OperatorKind::Eq => HirOperatorKind::Eq, + ast::OperatorKind::Ne => HirOperatorKind::Ne, + ast::OperatorKind::Lt => HirOperatorKind::Lt, + ast::OperatorKind::Le => HirOperatorKind::Le, + ast::OperatorKind::Gt => HirOperatorKind::Gt, + ast::OperatorKind::Ge => HirOperatorKind::Ge, + ast::OperatorKind::BitAnd => HirOperatorKind::BitAnd, + ast::OperatorKind::BitOr => HirOperatorKind::BitOr, + ast::OperatorKind::BitXor => HirOperatorKind::BitXor, + ast::OperatorKind::BitNot => HirOperatorKind::BitNot, + ast::OperatorKind::Shl => HirOperatorKind::Shl, + ast::OperatorKind::Shr => HirOperatorKind::Shr, } } diff --git a/core/hir/src/lib.rs b/core/hir/src/lib.rs index 9cdff681..c860beaf 100644 --- a/core/hir/src/lib.rs +++ b/core/hir/src/lib.rs @@ -1,6 +1,6 @@ #![warn(clippy::pedantic)] pub mod hir; -mod nodes; +pub mod nodes; mod nodes_impl; mod symbol_table; // mod type_inference; diff --git a/core/hir/src/nodes.rs b/core/hir/src/nodes.rs index 85fa01a4..4a8aaed6 100644 --- a/core/hir/src/nodes.rs +++ b/core/hir/src/nodes.rs @@ -168,6 +168,7 @@ pub struct ExternalFunctionDefinition { pub returns: Option, } +//TODO delete me #[derive(Clone, Debug)] pub struct TypeDefinition { pub name: String, @@ -230,6 +231,7 @@ pub struct VariableDefinitionStatement { pub is_uzumaki: bool, } +//TODO delete me #[derive(Clone, Debug)] pub struct TypeDefinitionStatement { pub name: String, diff --git a/core/hir/src/nodes_impl.rs b/core/hir/src/nodes_impl.rs index f72169c3..83477a74 100644 --- a/core/hir/src/nodes_impl.rs +++ b/core/hir/src/nodes_impl.rs @@ -1,5 +1,7 @@ +use std::rc::Rc; + use crate::{ - nodes::{Expression, Literal, UzumakiExpression}, + nodes::{Definition, Expression, FunctionDefinition, Literal, SourceFile, UzumakiExpression}, type_info::{NumberTypeKindNumberType, TypeInfo, TypeInfoKind}, }; @@ -47,16 +49,32 @@ impl Literal { impl UzumakiExpression { #[must_use] pub fn is_i32(&self) -> bool { - return matches!( + matches!( self.type_info.kind, TypeInfoKind::Number(NumberTypeKindNumberType::I32) - ); + ) } #[must_use] pub fn is_i64(&self) -> bool { - return matches!( + matches!( self.type_info.kind, TypeInfoKind::Number(NumberTypeKindNumberType::I64) - ); + ) + } +} + +impl SourceFile { + #[must_use] + pub fn function_definitions(&self) -> Vec> { + self.definitions + .iter() + .filter_map(|item| { + if let Definition::Function(func_def) = item { + Some(func_def.clone()) + } else { + None + } + }) + .collect() } } diff --git a/core/wasm-codegen/src/compiler.rs b/core/wasm-codegen/src/compiler.rs index 034c0ce1..117b813e 100644 --- a/core/wasm-codegen/src/compiler.rs +++ b/core/wasm-codegen/src/compiler.rs @@ -2,7 +2,7 @@ #![allow(dead_code)] use crate::utils; use inference_hir::{ - nodes::{BlockType, Expression, FunctionDefinition, Literal, Statement, Type}, + nodes::{BlockType, Expression, FunctionDefinition, Literal, Statement}, type_info::{NumberTypeKindNumberType, TypeInfoKind}, }; use inkwell::{ @@ -58,9 +58,9 @@ impl<'ctx> Compiler<'ctx> { } pub(crate) fn visit_function_definition(&self, function_definition: &Rc) { - let fn_name = function_definition.name(); + let fn_name = function_definition.name; let fn_type = match &function_definition.returns { - Some(ret_type) => match ret_type { + match &ret_type.kind { Type::Array(_array_type) => todo!(), Type::Simple(simple_type) => match simple_type.name.to_lowercase().as_str() { "i32" => self.context.i32_type().fn_type(&[], false), @@ -74,8 +74,7 @@ impl<'ctx> Compiler<'ctx> { Type::QualifiedName(_qualified_name) => todo!(), Type::Qualified(_type_qualified_name) => todo!(), Type::Custom(_identifier) => todo!(), - }, - None => self.context.void_type().fn_type(&[], false), + } }; let function = self.module.add_function(fn_name.as_str(), fn_type, None); @@ -206,7 +205,7 @@ impl<'ctx> Compiler<'ctx> { } Statement::TypeDefinition(_type_definition_statement) => todo!(), Statement::Assert(_assert_statement) => todo!(), - Statement::ConstantDefinition(constant_definition) => match &constant_definition.ty { + Statement::ConstantDefinition(constant_definition) => match &constant_definition.type_info { Type::Array(_type_array) => todo!(), Type::Simple(simple_type) => { match &simple_type diff --git a/core/wasm-codegen/src/lib.rs b/core/wasm-codegen/src/lib.rs index 7e931603..b227f616 100644 --- a/core/wasm-codegen/src/lib.rs +++ b/core/wasm-codegen/src/lib.rs @@ -1,13 +1,12 @@ #![warn(clippy::pedantic)] -use inference_ast::t_ast::TypedAst; +use crate::compiler::Compiler; +use inference_hir::hir::Hir; use inkwell::{ context::Context, targets::{InitializationConfig, Target}, }; -use crate::compiler::Compiler; - mod compiler; mod utils; @@ -19,26 +18,25 @@ mod utils; /// support is not yet implemented. /// /// Returns an error if code generation fails. -pub fn codegen(t_ast: &TypedAst) -> anyhow::Result> { +pub fn codegen(hir: &Hir) -> anyhow::Result> { Target::initialize_webassembly(&InitializationConfig::default()); let context = Context::create(); let compiler = Compiler::new(&context, "wasm_module"); - if t_ast.source_files.is_empty() { + if hir.arena.sources.is_empty() { return compiler.compile_to_wasm("output.wasm", 3); } - if t_ast.source_files.len() > 1 { + if hir.arena.sources.len() > 1 { todo!("Multi-file support not yet implemented"); } - traverse_t_ast_with_compiler(t_ast, &compiler); - + traverse_hir_with_compiler(hir, &compiler); let wasm_bytes = compiler.compile_to_wasm("output.wasm", 3)?; Ok(wasm_bytes) } -fn traverse_t_ast_with_compiler(t_ast: &TypedAst, compiler: &Compiler) { - for source_file in &t_ast.source_files { +fn traverse_hir_with_compiler(hir: &Hir, compiler: &Compiler) { + for source_file in &hir.arena.sources { for func_def in source_file.function_definitions() { compiler.visit_function_definition(&func_def); }