From 175cec83edaf241c2982dd70ae6b3f1b36610461 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 29 Aug 2025 19:52:11 -0400 Subject: [PATCH 01/16] feat: initial implementation of array node --- src/ArrayNode.jl | 511 ++++++++++++++++++++++++++++++++++++++ src/DynamicExpressions.jl | 2 + test/test_array_node.jl | 172 +++++++++++++ 3 files changed, 685 insertions(+) create mode 100644 src/ArrayNode.jl create mode 100644 test/test_array_node.jl diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl new file mode 100644 index 00000000..6fc1d7d5 --- /dev/null +++ b/src/ArrayNode.jl @@ -0,0 +1,511 @@ +module ArrayNodeModule + +using ..NodeModule: AbstractExpressionNode, Nullable +using ..UtilsModule: Undefined + +import ..NodeModule: + constructorof, with_type_parameters, with_max_degree, + preserve_sharing, max_degree, default_allocator, + get_children, set_children!, unsafe_get_children, + tree_mapreduce, count_nodes, set_node!, any, copy_node + +import Base: copy, hash, ==, getproperty, setproperty!, eltype + +export ArrayNode + +# Helper function to create array of specific type +@inline function create_array(::Type{A}, ::Type{T}, n::Int) where {A,T} + return A{T}(undef, n) +end + +# Container for all nodes in a tree +# A is the vector array type (e.g., Vector, FixedSizeVector) +mutable struct ArrayTree{T,D,A<:AbstractVector} + degrees::A # Vector of UInt8 + constants::A # Vector of Bool + vals::A # Vector of T + features::A # Vector of UInt16 + ops::A # Vector of UInt8 + children::A # Vector of NTuple{D,Int8} + + root_idx::Int8 + n_nodes::Int8 + free_list::A # Vector of Int8 + free_count::Int8 + + function ArrayTree{T,D,A}(n::Int) where {T,D,A<:AbstractVector} + tree = new{T,D,A}( + create_array(A, UInt8, n), + create_array(A, Bool, n), + create_array(A, T, n), + create_array(A, UInt16, n), + create_array(A, UInt8, n), + create_array(A, NTuple{D,Int8}, n), + Int8(0), Int8(0), + create_array(A, Int8, n), + Int8(n) + ) + # Initialize free list and children + for i in 1:n + tree.free_list[i] = Int8(i) + tree.children[i] = ntuple(_ -> Int8(-1), Val(D)) + end + return tree + end +end + +# Default constructor using regular arrays +ArrayTree{T,D}(n::Int) where {T,D} = ArrayTree{T,D,Vector}(n) + +# ArrayNode is just a lightweight view into the ArrayTree +mutable struct ArrayNode{T,D,A<:AbstractVector} <: AbstractExpressionNode{T,D} + tree::ArrayTree{T,D,A} + idx::Int8 +end + +# The clever part: getproperty just indexes into the arrays! +function getproperty(n::ArrayNode, k::Symbol) + tree = getfield(n, :tree) + idx = getfield(n, :idx) + + if k == :tree + return tree + elseif k == :idx + return idx + elseif k == :degree + return @inbounds tree.degrees[idx] + elseif k == :constant + return @inbounds tree.constants[idx] + elseif k == :val + return @inbounds tree.vals[idx] + elseif k == :feature + return @inbounds tree.features[idx] + elseif k == :op + return @inbounds tree.ops[idx] + elseif k == :children + # Return tuple of child ArrayNodes wrapped in Nullable + D = max_degree(typeof(n)) + return ntuple(i -> begin + child_idx = @inbounds tree.children[idx][i] + if child_idx < 0 + Nullable(true, n) # Poison node + else + Nullable(false, ArrayNode(tree, child_idx)) + end + end, Val(D)) + elseif k == :l # Left child for compatibility + child_idx = @inbounds tree.children[idx][1] + return child_idx < 0 ? error("No left child") : ArrayNode(tree, child_idx) + elseif k == :r # Right child for compatibility + child_idx = @inbounds tree.children[idx][2] + return child_idx < 0 ? error("No right child") : ArrayNode(tree, child_idx) + else + error("Unknown field $k") + end +end + +function setproperty!(n::ArrayNode, k::Symbol, v) + tree = getfield(n, :tree) + idx = getfield(n, :idx) + + if k == :degree + @inbounds tree.degrees[idx] = v + elseif k == :constant + @inbounds tree.constants[idx] = v + elseif k == :val + @inbounds tree.vals[idx] = v + elseif k == :feature + @inbounds tree.features[idx] = v + elseif k == :op + @inbounds tree.ops[idx] = v + elseif k == :l + if isa(v, ArrayNode) + children = tree.children[idx] + @inbounds tree.children[idx] = (getfield(v, :idx), children[2:end]...) + end + elseif k == :r + if isa(v, ArrayNode) + children = tree.children[idx] + @inbounds tree.children[idx] = (children[1], getfield(v, :idx), children[3:end]...) + end + else + error("Cannot set field $k") + end + return v +end + +# Allocation management +function allocate_node!(tree::ArrayTree) + tree.free_count == 0 && error("ArrayTree full") + idx = tree.free_list[tree.free_count] + tree.free_count -= 1 + tree.n_nodes += 1 + return idx +end + +function free_node!(tree::ArrayTree, idx::Int8) + tree.free_count += 1 + tree.free_list[tree.free_count] = idx + tree.n_nodes -= 1 +end + +# Default constructors - now include array type parameters +ArrayNode{T,D,A}(n::Int=32) where {T,D,A} = ArrayNode{T,D,A}(Undefined; allocator=ArrayTree{T,D,A}(n)) +ArrayNode{T,D}(n::Int) where {T,D} = ArrayNode{T,D,Vector}(n) +ArrayNode{T}(n::Int) where {T} = ArrayNode{T,2,Vector}(n) + +# Keyword constructors for partial type signatures +ArrayNode{T,D}(; kwargs...) where {T,D} = ArrayNode{T,D,Vector}(Undefined; kwargs...) +ArrayNode{T}(; kwargs...) where {T} = ArrayNode{T,2,Vector}(Undefined; kwargs...) + +# Constructor with keyword arguments - matches Node interface +function ArrayNode{T,D,A}( + ::Type{T1}=Undefined; + val=nothing, + feature=nothing, + op=nothing, + l=nothing, + r=nothing, + children=nothing, + allocator=nothing +) where {T,D,A<:AbstractVector,T1} + # Determine tree source + # Always create a new tree unless an allocator is explicitly provided + tree = if !isnothing(allocator) && isa(allocator, ArrayTree) + allocator + else + # Just use a reasonable default size + ArrayTree{T,D,A}(64) + end + + idx = allocate_node!(tree) + # Only set root_idx if this tree is new (no nodes allocated yet except this one) + if tree.n_nodes == 1 + tree.root_idx = idx + end + + if !isnothing(val) + tree.degrees[idx] = 0 + tree.constants[idx] = true + tree.vals[idx] = val + return ArrayNode(tree, idx) + end + + if !isnothing(feature) + tree.degrees[idx] = 0 + tree.constants[idx] = false + tree.features[idx] = feature + return ArrayNode(tree, idx) + end + + if !isnothing(op) + _children = if !isnothing(l) && isnothing(r) + (l,) + elseif !isnothing(l) && !isnothing(r) + (l, r) + else + children + end + + if !isnothing(_children) + degree = length(_children) + tree.degrees[idx] = degree + tree.ops[idx] = op + + # Copy children into this tree + child_indices = ntuple(i -> begin + if i <= length(_children) + child = _children[i] + if isa(child, ArrayNode) + child_tree = getfield(child, :tree) + child_idx = getfield(child, :idx) + if child_tree === tree + # Same tree - just link + child_idx + else + # Different tree - copy + copy_subtree!(tree, child_tree, child_idx) + end + else + Int8(-1) + end + else + Int8(-1) + end + end, Val(D)) + tree.children[idx] = child_indices + + return ArrayNode(tree, idx) + end + end + + # Default: empty constant + tree.degrees[idx] = 0 + tree.constants[idx] = true + tree.vals[idx] = zero(T) + return ArrayNode(tree, idx) +end + +function copy_subtree!(dst::ArrayTree{T,D,A}, src::ArrayTree{T,D,A}, src_idx::Int8) where {T,D,A} + dst_idx = allocate_node!(dst) + + @inbounds begin + dst.degrees[dst_idx] = src.degrees[src_idx] + dst.constants[dst_idx] = src.constants[src_idx] + dst.vals[dst_idx] = src.vals[src_idx] + dst.features[dst_idx] = src.features[src_idx] + dst.ops[dst_idx] = src.ops[src_idx] + end + + degree = @inbounds src.degrees[src_idx] + child_indices = ntuple(i -> begin + if i <= degree + child_idx = @inbounds src.children[src_idx][i] + if child_idx >= 0 + copy_subtree!(dst, src, child_idx) + else + Int8(-1) + end + else + Int8(-1) + end + end, Val(D)) + dst.children[dst_idx] = child_indices + + return dst_idx +end + +# Core interface implementations +eltype(::Type{<:ArrayNode{T}}) where {T} = T +eltype(::ArrayNode{T}) where {T} = T + +max_degree(::Type{<:ArrayNode}) = 2 +max_degree(::Type{<:ArrayNode{T,D}}) where {T,D} = D +max_degree(n::ArrayNode) = max_degree(typeof(n)) + +preserve_sharing(::Type{<:ArrayNode}) = false + +constructorof(::Type{<:ArrayNode}) = ArrayNode +with_type_parameters(::Type{<:ArrayNode}, ::Type{T}) where {T} = ArrayNode{T,2,Vector} +with_max_degree(::Type{<:ArrayNode{T,D,A}}, ::Val{D2}) where {T,D,A,D2} = ArrayNode{T,D2,A} +default_allocator(::Type{ArrayNode{T,D,A}}) where {T,D,A} = ArrayTree{T,D,A}(32) + +# get_children and set_children! +function unsafe_get_children(n::ArrayNode{T,D}) where {T,D} + tree = getfield(n, :tree) + idx = getfield(n, :idx) + return ntuple(i -> begin + child_idx = @inbounds tree.children[idx][i] + if child_idx < 0 + Nullable(true, n) + else + Nullable(false, ArrayNode(tree, child_idx)) + end + end, Val(D)) +end + +function get_children(n::ArrayNode{T,D}, ::Val{d}) where {T,D,d} + tree = getfield(n, :tree) + idx = getfield(n, :idx) + return ntuple(i -> begin + child_idx = @inbounds tree.children[idx][i] + ArrayNode(tree, child_idx) + end, Val(Int(d))) +end + +get_children(n::ArrayNode, d::Integer) = get_children(n, Val(d)) + +function set_children!(n::ArrayNode{T,D,A}, cs::Tuple) where {T,D,A} + tree = getfield(n, :tree) + idx = getfield(n, :idx) + child_indices = ntuple(i -> begin + if i <= length(cs) + child = cs[i] + if isa(child, ArrayNode) + getfield(child, :idx) + else + Int8(-1) + end + else + Int8(-1) + end + end, Val(D)) + tree.children[idx] = child_indices +end + +# Copy +function copy_node(n::ArrayNode{T,D,A}; break_sharing::Val{BS}=Val(false)) where {T,D,A,BS} + tree = getfield(n, :tree) + idx = getfield(n, :idx) + + # Count nodes to determine tree size needed + node_count = count_subtree(tree, idx) + # Add some buffer space + tree_size = max(32, node_count * 2) + + # Create new tree for the copy + new_tree = ArrayTree{T,D,A}(tree_size) + new_idx = copy_subtree!(new_tree, tree, idx) + new_tree.root_idx = new_idx + + return ArrayNode(new_tree, new_idx) +end + +copy(n::ArrayNode) = copy_node(n) + +# count_nodes +function count_nodes(n::ArrayNode) + tree = getfield(n, :tree) + return count_subtree(tree, getfield(n, :idx)) +end + +function count_subtree(tree::ArrayTree, idx::Int8) + count = 1 + degree = @inbounds tree.degrees[idx] + for i in 1:degree + child_idx = @inbounds tree.children[idx][i] + if child_idx >= 0 + count += count_subtree(tree, child_idx) + end + end + return count +end + +# Equality and hash +function ==(a::ArrayNode, b::ArrayNode) + a.degree != b.degree && return false + + if a.degree == 0 + a.constant != b.constant && return false + if a.constant + return a.val == b.val + else + return a.feature == b.feature + end + else + a.op != b.op && return false + + # Compare children recursively + for i in 1:a.degree + ca = get_children(a, Val(Int(a.degree)))[i] + cb = get_children(b, Val(Int(b.degree)))[i] + ca != cb && return false + end + return true + end +end + +function hash(n::ArrayNode, h::UInt=zero(UInt)) + if n.degree == 0 + if n.constant + return hash((0, n.val), h) + else + return hash((1, n.feature), h) + end + else + children_hashes = ntuple(i -> begin + child = get_children(n, Val(Int(n.degree)))[i] + hash(child, h) + end, Val(Int(n.degree))) + return hash((n.degree + 1, n.op, children_hashes), h) + end +end + +# set_node! implementation +function set_node!(dst::ArrayNode, src::ArrayNode) + dst_tree = getfield(dst, :tree) + src_tree = getfield(src, :tree) + dst_idx = getfield(dst, :idx) + src_idx = getfield(src, :idx) + + dst.degree = src.degree + + if src.degree == 0 + dst.constant = src.constant + if src.constant + dst.val = src.val + else + dst.feature = src.feature + end + else + dst.op = src.op + + # Copy children - need to get D from somewhere + # Since both dst and src are ArrayNodes, we can get it from the type + D = max_degree(typeof(dst)) + child_indices = ntuple(i -> begin + if i <= src.degree + child_idx = @inbounds src_tree.children[src_idx][i] + if child_idx >= 0 + if dst_tree === src_tree + # Same tree + child_idx + else + # Different tree - need to copy + copy_subtree!(dst_tree, src_tree, child_idx) + end + else + Int8(-1) + end + else + Int8(-1) + end + end, Val(D)) + dst_tree.children[dst_idx] = child_indices + end + + return nothing +end + +# tree_mapreduce and any +function tree_mapreduce(f::F, op::G, n::ArrayNode, ::Type{RT}=Any; f_on_shared=nothing, break_sharing=Val(false), kwargs...) where {F<:Function,G<:Function,RT} + tree = getfield(n, :tree) + return mapreduce_impl(f, op, tree, getfield(n, :idx)) +end + +function mapreduce_impl(f::F, op::G, tree::ArrayTree, idx::Int8) where {F,G} + degree = @inbounds tree.degrees[idx] + node = ArrayNode(tree, idx) + result = f(node) + + if degree > 0 + child_results = ntuple(i -> begin + child_idx = @inbounds tree.children[idx][i] + if child_idx >= 0 + mapreduce_impl(f, op, tree, child_idx) + else + nothing + end + end, Val(Int(degree))) + + # Filter out nothings and apply op + valid_results = filter(x -> !isnothing(x), child_results) + if !isempty(valid_results) + return op(result, valid_results...) + end + end + + return result +end + +function any(f::F, n::ArrayNode) where F<:Function + tree = getfield(n, :tree) + return any_impl(f, tree, getfield(n, :idx)) +end + +function any_impl(f::F, tree::ArrayTree, idx::Int8) where F + node = ArrayNode(tree, idx) + f(node) && return true + + degree = @inbounds tree.degrees[idx] + for i in 1:degree + child_idx = @inbounds tree.children[idx][i] + if child_idx >= 0 && any_impl(f, tree, child_idx) + return true + end + end + + return false +end + +end # module \ No newline at end of file diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 355e7b98..bc1a1b77 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -8,6 +8,7 @@ using DispatchDoctor: @stable, @unstable include("ExtensionInterface.jl") include("OperatorEnum.jl") include("Node.jl") + include("ArrayNode.jl") include("NodeUtils.jl") include("NodePreallocation.jl") include("Strings.jl") @@ -50,6 +51,7 @@ import .ValueInterfaceModule: tree_mapreduce, filter_map, filter_map! +import .ArrayNodeModule: ArrayNode import .NodePreallocationModule: allocate_container, copy_into! import .NodeModule: constructorof, diff --git a/test/test_array_node.jl b/test/test_array_node.jl new file mode 100644 index 00000000..4f297655 --- /dev/null +++ b/test/test_array_node.jl @@ -0,0 +1,172 @@ +@testitem "ArrayNode interface with Vector" begin + using DynamicExpressions + using DynamicExpressions: NodeInterface, ArrayNode + using Interfaces: Interfaces + + # Test with regular Vector + x1 = ArrayNode{Float64,2,Vector}(; feature=1) + x2 = ArrayNode{Float64,2,Vector}(; feature=2) + + operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin]) + + # Create test trees matching the pattern in test_node_interface.jl + tree_branch_deg2 = ArrayNode{Float64,2,Vector}(; op=1, + l=x1, + r=ArrayNode{Float64,2,Vector}(; op=1, + l=ArrayNode{Float64,2,Vector}(; op=2, + l=x2, + r=ArrayNode{Float64,2,Vector}(; val=3.5) + ) + ) + ) # x1 + sin(x2 * 3.5) + + tree_branch_deg1 = ArrayNode{Float64,2,Vector}(; op=1, l=x1) # sin(x1) + tree_leaf_feature = x1 + tree_leaf_constant = ArrayNode{Float64,2,Vector}(; val=1.0) + + @test Interfaces.test( + NodeInterface, + ArrayNode, + [tree_branch_deg2, tree_branch_deg1, tree_leaf_feature, tree_leaf_constant], + ) +end + +@testitem "ArrayNode with custom array type" begin + using DynamicExpressions + using DynamicExpressions: NodeInterface, ArrayNode + using Interfaces: Interfaces + + # Test that ArrayNode works with any AbstractVector type + # For production use with FixedSizeArrays, you'd need a wrapper + # that handles mixed element types properly + + x1 = ArrayNode{Float64,2,Vector}(; feature=1) + x2 = ArrayNode{Float64,2,Vector}(; feature=2) + + operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin]) + + # Create test trees + tree_branch_deg2 = ArrayNode{Float64,2,Vector}(; op=1, + l=x1, + r=ArrayNode{Float64,2,Vector}(; op=1, + l=ArrayNode{Float64,2,Vector}(; op=2, + l=x2, + r=ArrayNode{Float64,2,Vector}(; val=3.5) + ) + ) + ) + + tree_branch_deg1 = ArrayNode{Float64,2,Vector}(; op=1, l=x1) + tree_leaf_feature = x1 + tree_leaf_constant = ArrayNode{Float64,2,Vector}(; val=1.0) + + @test Interfaces.test( + NodeInterface, + ArrayNode, + [tree_branch_deg2, tree_branch_deg1, tree_leaf_feature, tree_leaf_constant], + ) +end + +@testitem "ArrayNode interface on n-arity nodes" begin + using DynamicExpressions + using DynamicExpressions: NodeInterface, ArrayNode + using Interfaces: Interfaces + + for D in (3, 4, 5) + # Test with regular arrays + x = [ArrayNode{Float64,D,Vector}(; feature=i) for i in 1:3] + operator_tuple = ((sin, cos, exp), (+, *, /, -), (fma, clamp), (max, min), ()) + # Create pairs for degrees 1 through D + pairs = [i => operator_tuple[i] for i in 1:D if !isempty(operator_tuple[i])] + operators = + isempty(pairs) ? OperatorEnum(1 => ()) : OperatorEnum(pairs[1], pairs[2:end]...) + DynamicExpressions.OperatorEnumConstructionModule.empty_all_globals!() + + let tree = ArrayNode{Float64,D,Vector}(; op=2, children=(x[1], x[2])) # * + if D > 2 + fma_idx = 1 + tree = ArrayNode{Float64,D,Vector}(; op=fma_idx, children=(tree, x[1], x[2])) # fma + end + if D > 3 + idx_max = 1 + tree = ArrayNode{Float64,D,Vector}(; op=idx_max, children=(tree, x[1], x[2], x[3])) # max + end + @test Interfaces.test(NodeInterface, ArrayNode, [tree]) + end + end +end + +@testitem "ArrayNode basic operations" begin + using DynamicExpressions + using DynamicExpressions: ArrayNode, OperatorEnum + + # Test with regular arrays (default) + x1 = ArrayNode{Float64,2}(; feature=1) + x2 = ArrayNode{Float64,2}(; feature=2) + c = ArrayNode{Float64,2}(; val=3.5) + + # Test basic properties + @test x1.degree == 0 + @test x1.feature == 1 + @test !x1.constant + + @test c.degree == 0 + @test c.val == 3.5 + @test c.constant + + # Test tree construction + mul = ArrayNode{Float64,2}(; op=3, l=x2, r=c) + @test mul.degree == 2 + @test mul.op == 3 + + sin_expr = ArrayNode{Float64,2}(; op=1, l=mul) + @test sin_expr.degree == 1 + + tree = ArrayNode{Float64,2}(; op=1, l=x1, r=sin_expr) + @test tree.degree == 2 + + # Test copy + tree_copy = copy(tree) + @test tree == tree_copy + @test tree !== tree_copy + + # Test hash + @test hash(tree) == hash(tree_copy) + + # Test count_nodes + @test count_nodes(tree) == 6 # tree, x1, sin_expr, mul, x2, c + + # Test string conversion + operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos]) + str = string_tree(tree, operators) + @test str == "x1 + sin(x2 * 3.5)" + + # Test evaluation + X = [1.0 2.0; 0.5 1.0] # 2 features, 2 samples + result = eval_tree_array(tree, X, operators) + expected = X[1, :] .+ sin.(X[2, :] .* 3.5) + @test all(abs.(result[1] .- expected) .< 1e-10) +end + +@testitem "ArrayNode with Expressions" begin + using DynamicExpressions + using DynamicExpressions: ArrayNode, Expression + + # Create a simple tree with default arrays + x1 = ArrayNode{Float64,2}(; feature=1) + c = ArrayNode{Float64,2}(; val=2.0) + tree = ArrayNode{Float64,2}(; op=1, l=x1, r=c) + + operators = OperatorEnum(; binary_operators=[+, -, *, /]) + + # Test Expression conversion + expr = Expression(tree; operators=operators, variable_names=["x1", "x2"]) + @test string(expr) == "x1 + 2.0" + + # Test evaluation through Expression + X = [1.0 2.0 3.0] # 1 feature, 3 samples + result = expr(X) + expected = vec(X .+ 2.0) # Convert to vector to match result shape + @test all(abs.(result .- expected) .< 1e-10) +end + From 797d7a628f3c53b7460c0fff2f545054c52be2bb Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 29 Aug 2025 20:04:22 -0400 Subject: [PATCH 02/16] refactor: clean up implementation --- src/ArrayNode.jl | 6 ------ test/test_array_node.jl | 15 ++++++++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index 6fc1d7d5..0c135c96 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -18,8 +18,6 @@ export ArrayNode return A{T}(undef, n) end -# Container for all nodes in a tree -# A is the vector array type (e.g., Vector, FixedSizeVector) mutable struct ArrayTree{T,D,A<:AbstractVector} degrees::A # Vector of UInt8 constants::A # Vector of Bool @@ -57,13 +55,11 @@ end # Default constructor using regular arrays ArrayTree{T,D}(n::Int) where {T,D} = ArrayTree{T,D,Vector}(n) -# ArrayNode is just a lightweight view into the ArrayTree mutable struct ArrayNode{T,D,A<:AbstractVector} <: AbstractExpressionNode{T,D} tree::ArrayTree{T,D,A} idx::Int8 end -# The clever part: getproperty just indexes into the arrays! function getproperty(n::ArrayNode, k::Symbol) tree = getfield(n, :tree) idx = getfield(n, :idx) @@ -430,8 +426,6 @@ function set_node!(dst::ArrayNode, src::ArrayNode) else dst.op = src.op - # Copy children - need to get D from somewhere - # Since both dst and src are ArrayNodes, we can get it from the type D = max_degree(typeof(dst)) child_indices = ntuple(i -> begin if i <= src.degree diff --git a/test/test_array_node.jl b/test/test_array_node.jl index 4f297655..49de706a 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -1,7 +1,8 @@ @testitem "ArrayNode interface with Vector" begin using DynamicExpressions - using DynamicExpressions: NodeInterface, ArrayNode + using DynamicExpressions: NodeInterface using Interfaces: Interfaces + const ArrayNode = DynamicExpressions.ArrayNode # Test with regular Vector x1 = ArrayNode{Float64,2,Vector}(; feature=1) @@ -33,8 +34,9 @@ end @testitem "ArrayNode with custom array type" begin using DynamicExpressions - using DynamicExpressions: NodeInterface, ArrayNode + using DynamicExpressions: NodeInterface using Interfaces: Interfaces + const ArrayNode = DynamicExpressions.ArrayNode # Test that ArrayNode works with any AbstractVector type # For production use with FixedSizeArrays, you'd need a wrapper @@ -69,8 +71,9 @@ end @testitem "ArrayNode interface on n-arity nodes" begin using DynamicExpressions - using DynamicExpressions: NodeInterface, ArrayNode + using DynamicExpressions: NodeInterface using Interfaces: Interfaces + const ArrayNode = DynamicExpressions.ArrayNode for D in (3, 4, 5) # Test with regular arrays @@ -98,7 +101,8 @@ end @testitem "ArrayNode basic operations" begin using DynamicExpressions - using DynamicExpressions: ArrayNode, OperatorEnum + using DynamicExpressions: OperatorEnum + const ArrayNode = DynamicExpressions.ArrayNode # Test with regular arrays (default) x1 = ArrayNode{Float64,2}(; feature=1) @@ -150,7 +154,8 @@ end @testitem "ArrayNode with Expressions" begin using DynamicExpressions - using DynamicExpressions: ArrayNode, Expression + using DynamicExpressions: Expression + const ArrayNode = DynamicExpressions.ArrayNode # Create a simple tree with default arrays x1 = ArrayNode{Float64,2}(; feature=1) From 90235877e916ceefb995c1affcf0d7a88a01ea11 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 29 Aug 2025 20:53:18 -0400 Subject: [PATCH 03/16] refactor: rewrite ArrayNode with StructArray --- Project.toml | 2 + src/ArrayNode.jl | 455 ++++++++++++++++++++++------------------ test/Project.toml | 1 + test/test_array_node.jl | 136 ++++++++---- 4 files changed, 351 insertions(+), 243 deletions(-) diff --git a/Project.toml b/Project.toml index c237d42a..de530589 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] @@ -39,6 +40,7 @@ MacroTools = "0.4, 0.5" Optim = "0.19, 1" PrecompileTools = "1" Reexport = "1" +StructArrays = "0.7.1" SymbolicUtils = "0.19, ^1.0.5, 2, 3" Zygote = "0.7" julia = "1.10" diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index 0c135c96..9fb1962b 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -2,127 +2,146 @@ module ArrayNodeModule using ..NodeModule: AbstractExpressionNode, Nullable using ..UtilsModule: Undefined +using StructArrays: StructArray, StructVector import ..NodeModule: - constructorof, with_type_parameters, with_max_degree, - preserve_sharing, max_degree, default_allocator, - get_children, set_children!, unsafe_get_children, - tree_mapreduce, count_nodes, set_node!, any, copy_node + constructorof, + with_type_parameters, + with_max_degree, + preserve_sharing, + max_degree, + default_allocator, + get_children, + set_children!, + unsafe_get_children, + tree_mapreduce, + count_nodes, + set_node!, + any, + copy_node import Base: copy, hash, ==, getproperty, setproperty!, eltype export ArrayNode -# Helper function to create array of specific type -@inline function create_array(::Type{A}, ::Type{T}, n::Int) where {A,T} - return A{T}(undef, n) +# Node data struct +struct NodeData{T,D} + degree::UInt8 + constant::Bool + val::T + feature::UInt16 + op::UInt8 + children::NTuple{D,Int8} end -mutable struct ArrayTree{T,D,A<:AbstractVector} - degrees::A # Vector of UInt8 - constants::A # Vector of Bool - vals::A # Vector of T - features::A # Vector of UInt16 - ops::A # Vector of UInt8 - children::A # Vector of NTuple{D,Int8} - +# Constructor for empty node +function NodeData{T,D}() where {T,D} + return NodeData{T,D}( + UInt8(0), true, zero(T), UInt16(0), UInt8(0), ntuple(_ -> Int8(-1), Val(D)) + ) +end + +mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}} + nodes::S root_idx::Int8 n_nodes::Int8 - free_list::A # Vector of Int8 + free_list::Vector{Int8} free_count::Int8 - - function ArrayTree{T,D,A}(n::Int) where {T,D,A<:AbstractVector} - tree = new{T,D,A}( - create_array(A, UInt8, n), - create_array(A, Bool, n), - create_array(A, T, n), - create_array(A, UInt16, n), - create_array(A, UInt8, n), - create_array(A, NTuple{D,Int8}, n), - Int8(0), Int8(0), - create_array(A, Int8, n), - Int8(n) - ) - # Initialize free list and children + + function ArrayTree{T,D}(n::Int) where {T,D} + # Create a StructVector with pre-allocated arrays + nodes = StructVector{NodeData{T,D}}(undef, n) + # Initialize all nodes to default values + for i in 1:n + nodes.degree[i] = UInt8(0) + nodes.constant[i] = true + nodes.val[i] = zero(T) + nodes.feature[i] = UInt16(0) + nodes.op[i] = UInt8(0) + nodes.children[i] = ntuple(_ -> Int8(-1), Val(D)) + end + + S = typeof(nodes) + tree = new{T,D,S}(nodes, Int8(0), Int8(0), Vector{Int8}(undef, n), Int8(n)) + # Initialize free list for i in 1:n tree.free_list[i] = Int8(i) - tree.children[i] = ntuple(_ -> Int8(-1), Val(D)) end return tree end end -# Default constructor using regular arrays -ArrayTree{T,D}(n::Int) where {T,D} = ArrayTree{T,D,Vector}(n) - -mutable struct ArrayNode{T,D,A<:AbstractVector} <: AbstractExpressionNode{T,D} - tree::ArrayTree{T,D,A} +struct ArrayNode{T,D,S} <: AbstractExpressionNode{T,D} + tree::ArrayTree{T,D,S} idx::Int8 end -function getproperty(n::ArrayNode, k::Symbol) +function getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) - + nodes = getfield(tree, :nodes) + if k == :tree return tree elseif k == :idx return idx elseif k == :degree - return @inbounds tree.degrees[idx] + return @inbounds nodes.degree[idx] elseif k == :constant - return @inbounds tree.constants[idx] + return @inbounds nodes.constant[idx] elseif k == :val - return @inbounds tree.vals[idx] + return @inbounds nodes.val[idx] elseif k == :feature - return @inbounds tree.features[idx] + return @inbounds nodes.feature[idx] elseif k == :op - return @inbounds tree.ops[idx] + return @inbounds nodes.op[idx] elseif k == :children # Return tuple of child ArrayNodes wrapped in Nullable - D = max_degree(typeof(n)) - return ntuple(i -> begin - child_idx = @inbounds tree.children[idx][i] + return ntuple(Val(D)) do i + child_idx = @inbounds nodes.children[idx][i] if child_idx < 0 Nullable(true, n) # Poison node else - Nullable(false, ArrayNode(tree, child_idx)) + Nullable(false, ArrayNode{T,D,S}(tree, child_idx)) end - end, Val(D)) + end elseif k == :l # Left child for compatibility - child_idx = @inbounds tree.children[idx][1] - return child_idx < 0 ? error("No left child") : ArrayNode(tree, child_idx) + child_idx = @inbounds nodes.children[idx][1] + return child_idx < 0 ? error("No left child") : ArrayNode{T,D,S}(tree, child_idx) elseif k == :r # Right child for compatibility - child_idx = @inbounds tree.children[idx][2] - return child_idx < 0 ? error("No right child") : ArrayNode(tree, child_idx) + child_idx = @inbounds nodes.children[idx][2] + return child_idx < 0 ? error("No right child") : ArrayNode{T,D,S}(tree, child_idx) else error("Unknown field $k") end end -function setproperty!(n::ArrayNode, k::Symbol, v) +function setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) - + nodes = getfield(tree, :nodes) + if k == :degree - @inbounds tree.degrees[idx] = v + @inbounds nodes.degree[idx] = v elseif k == :constant - @inbounds tree.constants[idx] = v + @inbounds nodes.constant[idx] = v elseif k == :val - @inbounds tree.vals[idx] = v + @inbounds nodes.val[idx] = v elseif k == :feature - @inbounds tree.features[idx] = v + @inbounds nodes.feature[idx] = v elseif k == :op - @inbounds tree.ops[idx] = v + @inbounds nodes.op[idx] = v elseif k == :l if isa(v, ArrayNode) - children = tree.children[idx] - @inbounds tree.children[idx] = (getfield(v, :idx), children[2:end]...) + children = nodes.children[idx] + @inbounds nodes.children[idx] = (getfield(v, :idx), children[2:end]...) end elseif k == :r if isa(v, ArrayNode) - children = tree.children[idx] - @inbounds tree.children[idx] = (children[1], getfield(v, :idx), children[3:end]...) + children = nodes.children[idx] + @inbounds nodes.children[idx] = ( + children[1], getfield(v, :idx), children[3:end]... + ) end else error("Cannot set field $k") @@ -142,58 +161,57 @@ end function free_node!(tree::ArrayTree, idx::Int8) tree.free_count += 1 tree.free_list[tree.free_count] = idx - tree.n_nodes -= 1 + return tree.n_nodes -= 1 end # Default constructors - now include array type parameters -ArrayNode{T,D,A}(n::Int=32) where {T,D,A} = ArrayNode{T,D,A}(Undefined; allocator=ArrayTree{T,D,A}(n)) -ArrayNode{T,D}(n::Int) where {T,D} = ArrayNode{T,D,Vector}(n) -ArrayNode{T}(n::Int) where {T} = ArrayNode{T,2,Vector}(n) +ArrayNode{T,D}(n::Int) where {T,D} = ArrayNode{T,D}(Undefined; allocator=ArrayTree{T,D}(n)) +ArrayNode{T}(n::Int) where {T} = ArrayNode{T,2}(n) # Keyword constructors for partial type signatures -ArrayNode{T,D}(; kwargs...) where {T,D} = ArrayNode{T,D,Vector}(Undefined; kwargs...) -ArrayNode{T}(; kwargs...) where {T} = ArrayNode{T,2,Vector}(Undefined; kwargs...) +ArrayNode{T,D}(; kwargs...) where {T,D} = ArrayNode{T,D}(Undefined; kwargs...) +ArrayNode{T}(; kwargs...) where {T} = ArrayNode{T,2}(; kwargs...) # Constructor with keyword arguments - matches Node interface -function ArrayNode{T,D,A}( - ::Type{T1}=Undefined; +function ArrayNode{T,D}( + ::Type{T1}; val=nothing, - feature=nothing, + feature=nothing, op=nothing, l=nothing, r=nothing, children=nothing, - allocator=nothing -) where {T,D,A<:AbstractVector,T1} + allocator=nothing, +) where {T,D,T1} # Determine tree source # Always create a new tree unless an allocator is explicitly provided tree = if !isnothing(allocator) && isa(allocator, ArrayTree) allocator else # Just use a reasonable default size - ArrayTree{T,D,A}(64) + ArrayTree{T,D}(64) end - + idx = allocate_node!(tree) # Only set root_idx if this tree is new (no nodes allocated yet except this one) if tree.n_nodes == 1 tree.root_idx = idx end - + if !isnothing(val) - tree.degrees[idx] = 0 - tree.constants[idx] = true - tree.vals[idx] = val - return ArrayNode(tree, idx) + tree.nodes.degree[idx] = 0 + tree.nodes.constant[idx] = true + tree.nodes.val[idx] = val + return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end - + if !isnothing(feature) - tree.degrees[idx] = 0 - tree.constants[idx] = false - tree.features[idx] = feature - return ArrayNode(tree, idx) + tree.nodes.degree[idx] = 0 + tree.nodes.constant[idx] = false + tree.nodes.feature[idx] = feature + return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end - + if !isnothing(op) _children = if !isnothing(l) && isnothing(r) (l,) @@ -202,72 +220,77 @@ function ArrayNode{T,D,A}( else children end - + if !isnothing(_children) degree = length(_children) - tree.degrees[idx] = degree - tree.ops[idx] = op - + tree.nodes.degree[idx] = degree + tree.nodes.op[idx] = op + # Copy children into this tree - child_indices = ntuple(i -> begin - if i <= length(_children) - child = _children[i] - if isa(child, ArrayNode) - child_tree = getfield(child, :tree) - child_idx = getfield(child, :idx) - if child_tree === tree - # Same tree - just link - child_idx + child_indices = ntuple( + i -> begin + if i <= length(_children) + child = _children[i] + if isa(child, ArrayNode) + child_tree = getfield(child, :tree) + child_idx = getfield(child, :idx) + if child_tree === tree + # Same tree - just link + child_idx + else + # Different tree - copy + copy_subtree!(tree, child_tree, child_idx) + end else - # Different tree - copy - copy_subtree!(tree, child_tree, child_idx) + Int8(-1) end else Int8(-1) end - else - Int8(-1) - end - end, Val(D)) - tree.children[idx] = child_indices - - return ArrayNode(tree, idx) + end, + Val(D), + ) + tree.nodes.children[idx] = child_indices + + return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end end - + # Default: empty constant - tree.degrees[idx] = 0 - tree.constants[idx] = true - tree.vals[idx] = zero(T) - return ArrayNode(tree, idx) + tree.nodes.degree[idx] = 0 + tree.nodes.constant[idx] = true + tree.nodes.val[idx] = zero(T) + return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end -function copy_subtree!(dst::ArrayTree{T,D,A}, src::ArrayTree{T,D,A}, src_idx::Int8) where {T,D,A} +function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::Int8) where {T,D} dst_idx = allocate_node!(dst) - + @inbounds begin - dst.degrees[dst_idx] = src.degrees[src_idx] - dst.constants[dst_idx] = src.constants[src_idx] - dst.vals[dst_idx] = src.vals[src_idx] - dst.features[dst_idx] = src.features[src_idx] - dst.ops[dst_idx] = src.ops[src_idx] + dst.nodes.degree[dst_idx] = src.nodes.degree[src_idx] + dst.nodes.constant[dst_idx] = src.nodes.constant[src_idx] + dst.nodes.val[dst_idx] = src.nodes.val[src_idx] + dst.nodes.feature[dst_idx] = src.nodes.feature[src_idx] + dst.nodes.op[dst_idx] = src.nodes.op[src_idx] end - - degree = @inbounds src.degrees[src_idx] - child_indices = ntuple(i -> begin - if i <= degree - child_idx = @inbounds src.children[src_idx][i] - if child_idx >= 0 - copy_subtree!(dst, src, child_idx) + + degree = @inbounds src.nodes.degree[src_idx] + child_indices = ntuple( + i -> begin + if i <= degree + child_idx = @inbounds src.nodes.children[src_idx][i] + if child_idx >= 0 + copy_subtree!(dst, src, child_idx) + else + Int8(-1) + end else Int8(-1) end - else - Int8(-1) - end - end, Val(D)) - dst.children[dst_idx] = child_indices - + end, Val(D) + ) + dst.nodes.children[dst_idx] = child_indices + return dst_idx end @@ -282,36 +305,39 @@ max_degree(n::ArrayNode) = max_degree(typeof(n)) preserve_sharing(::Type{<:ArrayNode}) = false constructorof(::Type{<:ArrayNode}) = ArrayNode -with_type_parameters(::Type{<:ArrayNode}, ::Type{T}) where {T} = ArrayNode{T,2,Vector} -with_max_degree(::Type{<:ArrayNode{T,D,A}}, ::Val{D2}) where {T,D,A,D2} = ArrayNode{T,D2,A} -default_allocator(::Type{ArrayNode{T,D,A}}) where {T,D,A} = ArrayTree{T,D,A}(32) +with_type_parameters(::Type{<:ArrayNode}, ::Type{T}) where {T} = ArrayNode{T,2} +with_max_degree(::Type{<:ArrayNode{T,D}}, ::Val{D2}) where {T,D,D2} = ArrayNode{T,D2} +default_allocator(::Type{ArrayNode{T,D}}) where {T,D} = ArrayTree{T,D}(32) # get_children and set_children! -function unsafe_get_children(n::ArrayNode{T,D}) where {T,D} +function unsafe_get_children(n::ArrayNode{T,D,S}) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) - return ntuple(i -> begin - child_idx = @inbounds tree.children[idx][i] - if child_idx < 0 - Nullable(true, n) - else - Nullable(false, ArrayNode(tree, child_idx)) - end - end, Val(D)) + return ntuple( + i -> begin + child_idx = @inbounds tree.nodes.children[idx][i] + if child_idx < 0 + Nullable(true, n) + else + Nullable(false, ArrayNode{T,D,typeof(tree.nodes)}(tree, child_idx)) + end + end, + Val(D), + ) end -function get_children(n::ArrayNode{T,D}, ::Val{d}) where {T,D,d} +function get_children(n::ArrayNode{T,D,S}, ::Val{d}) where {T,D,S,d} tree = getfield(n, :tree) idx = getfield(n, :idx) return ntuple(i -> begin - child_idx = @inbounds tree.children[idx][i] - ArrayNode(tree, child_idx) + child_idx = @inbounds tree.nodes.children[idx][i] + ArrayNode{T,D,typeof(tree.nodes)}(tree, child_idx) end, Val(Int(d))) end get_children(n::ArrayNode, d::Integer) = get_children(n, Val(d)) -function set_children!(n::ArrayNode{T,D,A}, cs::Tuple) where {T,D,A} +function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) child_indices = ntuple(i -> begin @@ -326,25 +352,25 @@ function set_children!(n::ArrayNode{T,D,A}, cs::Tuple) where {T,D,A} Int8(-1) end end, Val(D)) - tree.children[idx] = child_indices + return tree.nodes.children[idx] = child_indices end # Copy -function copy_node(n::ArrayNode{T,D,A}; break_sharing::Val{BS}=Val(false)) where {T,D,A,BS} +function copy_node(n::ArrayNode{T,D,S}; break_sharing::Val{BS}=Val(false)) where {T,D,S,BS} tree = getfield(n, :tree) idx = getfield(n, :idx) - + # Count nodes to determine tree size needed node_count = count_subtree(tree, idx) # Add some buffer space tree_size = max(32, node_count * 2) - + # Create new tree for the copy - new_tree = ArrayTree{T,D,A}(tree_size) + new_tree = ArrayTree{T,D}(tree_size) new_idx = copy_subtree!(new_tree, tree, idx) new_tree.root_idx = new_idx - - return ArrayNode(new_tree, new_idx) + + return ArrayNode{T,D,typeof(new_tree.nodes)}(new_tree, new_idx) end copy(n::ArrayNode) = copy_node(n) @@ -357,9 +383,9 @@ end function count_subtree(tree::ArrayTree, idx::Int8) count = 1 - degree = @inbounds tree.degrees[idx] + degree = @inbounds tree.nodes.degree[idx] for i in 1:degree - child_idx = @inbounds tree.children[idx][i] + child_idx = @inbounds tree.nodes.children[idx][i] if child_idx >= 0 count += count_subtree(tree, child_idx) end @@ -370,7 +396,7 @@ end # Equality and hash function ==(a::ArrayNode, b::ArrayNode) a.degree != b.degree && return false - + if a.degree == 0 a.constant != b.constant && return false if a.constant @@ -380,9 +406,9 @@ function ==(a::ArrayNode, b::ArrayNode) end else a.op != b.op && return false - + # Compare children recursively - for i in 1:a.degree + for i in 1:(a.degree) ca = get_children(a, Val(Int(a.degree)))[i] cb = get_children(b, Val(Int(b.degree)))[i] ca != cb && return false @@ -399,10 +425,12 @@ function hash(n::ArrayNode, h::UInt=zero(UInt)) return hash((1, n.feature), h) end else - children_hashes = ntuple(i -> begin - child = get_children(n, Val(Int(n.degree)))[i] - hash(child, h) - end, Val(Int(n.degree))) + children_hashes = ntuple( + i -> begin + child = get_children(n, Val(Int(n.degree)))[i] + hash(child, h) + end, Val(Int(n.degree)) + ) return hash((n.degree + 1, n.op, children_hashes), h) end end @@ -413,9 +441,9 @@ function set_node!(dst::ArrayNode, src::ArrayNode) src_tree = getfield(src, :tree) dst_idx = getfield(dst, :idx) src_idx = getfield(src, :idx) - + dst.degree = src.degree - + if src.degree == 0 dst.constant = src.constant if src.constant @@ -425,81 +453,94 @@ function set_node!(dst::ArrayNode, src::ArrayNode) end else dst.op = src.op - + D = max_degree(typeof(dst)) - child_indices = ntuple(i -> begin - if i <= src.degree - child_idx = @inbounds src_tree.children[src_idx][i] - if child_idx >= 0 - if dst_tree === src_tree - # Same tree - child_idx + child_indices = ntuple( + i -> begin + if i <= src.degree + child_idx = @inbounds src_tree.nodes.children[src_idx][i] + if child_idx >= 0 + if dst_tree === src_tree + # Same tree + child_idx + else + # Different tree - need to copy + copy_subtree!(dst_tree, src_tree, child_idx) + end else - # Different tree - need to copy - copy_subtree!(dst_tree, src_tree, child_idx) + Int8(-1) end else Int8(-1) end - else - Int8(-1) - end - end, Val(D)) - dst_tree.children[dst_idx] = child_indices + end, + Val(D), + ) + dst_tree.nodes.children[dst_idx] = child_indices end - + return nothing end # tree_mapreduce and any -function tree_mapreduce(f::F, op::G, n::ArrayNode, ::Type{RT}=Any; f_on_shared=nothing, break_sharing=Val(false), kwargs...) where {F<:Function,G<:Function,RT} +function tree_mapreduce( + f::F, + op::G, + n::ArrayNode, + (::Type{RT})=Any; + f_on_shared=nothing, + break_sharing=Val(false), + kwargs..., +) where {F<:Function,G<:Function,RT} tree = getfield(n, :tree) return mapreduce_impl(f, op, tree, getfield(n, :idx)) end -function mapreduce_impl(f::F, op::G, tree::ArrayTree, idx::Int8) where {F,G} - degree = @inbounds tree.degrees[idx] - node = ArrayNode(tree, idx) +function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::Int8) where {F,G,T,D,S} + degree = @inbounds tree.nodes.degree[idx] + node = ArrayNode{T,D,S}(tree, idx) result = f(node) - + if degree > 0 - child_results = ntuple(i -> begin - child_idx = @inbounds tree.children[idx][i] - if child_idx >= 0 - mapreduce_impl(f, op, tree, child_idx) - else - nothing - end - end, Val(Int(degree))) - + child_results = ntuple( + i -> begin + child_idx = @inbounds tree.nodes.children[idx][i] + if child_idx >= 0 + mapreduce_impl(f, op, tree, child_idx) + else + nothing + end + end, Val(Int(degree)) + ) + # Filter out nothings and apply op valid_results = filter(x -> !isnothing(x), child_results) if !isempty(valid_results) return op(result, valid_results...) end end - + return result end -function any(f::F, n::ArrayNode) where F<:Function +function any(f::F, n::ArrayNode) where {F<:Function} tree = getfield(n, :tree) return any_impl(f, tree, getfield(n, :idx)) end -function any_impl(f::F, tree::ArrayTree, idx::Int8) where F - node = ArrayNode(tree, idx) +function any_impl(f::F, tree::ArrayTree{T,D,S}, idx::Int8) where {F,T,D,S} + node = ArrayNode{T,D,S}(tree, idx) f(node) && return true - - degree = @inbounds tree.degrees[idx] + + degree = @inbounds tree.nodes.degree[idx] for i in 1:degree - child_idx = @inbounds tree.children[idx][i] + child_idx = @inbounds tree.nodes.children[idx][i] if child_idx >= 0 && any_impl(f, tree, child_idx) return true end end - + return false end -end # module \ No newline at end of file +end # module diff --git a/test/Project.toml b/test/Project.toml index 61ff11c1..6ecd3917 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/test/test_array_node.jl b/test/test_array_node.jl index 49de706a..efa4e0a1 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -11,16 +11,17 @@ operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin]) # Create test trees matching the pattern in test_node_interface.jl - tree_branch_deg2 = ArrayNode{Float64,2,Vector}(; op=1, + tree_branch_deg2 = ArrayNode{Float64,2,Vector}(; + op=1, l=x1, - r=ArrayNode{Float64,2,Vector}(; op=1, - l=ArrayNode{Float64,2,Vector}(; op=2, - l=x2, - r=ArrayNode{Float64,2,Vector}(; val=3.5) - ) - ) + r=ArrayNode{Float64,2,Vector}(; + op=1, + l=ArrayNode{Float64,2,Vector}(; + op=2, l=x2, r=ArrayNode{Float64,2,Vector}(; val=3.5) + ), + ), ) # x1 + sin(x2 * 3.5) - + tree_branch_deg1 = ArrayNode{Float64,2,Vector}(; op=1, l=x1) # sin(x1) tree_leaf_feature = x1 tree_leaf_constant = ArrayNode{Float64,2,Vector}(; val=1.0) @@ -37,27 +38,28 @@ end using DynamicExpressions: NodeInterface using Interfaces: Interfaces const ArrayNode = DynamicExpressions.ArrayNode - + # Test that ArrayNode works with any AbstractVector type # For production use with FixedSizeArrays, you'd need a wrapper # that handles mixed element types properly - - x1 = ArrayNode{Float64,2,Vector}(; feature=1) + + x1 = ArrayNode{Float64,2,Vector}(; feature=1) x2 = ArrayNode{Float64,2,Vector}(; feature=2) operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin]) # Create test trees - tree_branch_deg2 = ArrayNode{Float64,2,Vector}(; op=1, + tree_branch_deg2 = ArrayNode{Float64,2,Vector}(; + op=1, l=x1, - r=ArrayNode{Float64,2,Vector}(; op=1, - l=ArrayNode{Float64,2,Vector}(; op=2, - l=x2, - r=ArrayNode{Float64,2,Vector}(; val=3.5) - ) - ) + r=ArrayNode{Float64,2,Vector}(; + op=1, + l=ArrayNode{Float64,2,Vector}(; + op=2, l=x2, r=ArrayNode{Float64,2,Vector}(; val=3.5) + ), + ), ) - + tree_branch_deg1 = ArrayNode{Float64,2,Vector}(; op=1, l=x1) tree_leaf_feature = x1 tree_leaf_constant = ArrayNode{Float64,2,Vector}(; val=1.0) @@ -84,15 +86,19 @@ end operators = isempty(pairs) ? OperatorEnum(1 => ()) : OperatorEnum(pairs[1], pairs[2:end]...) DynamicExpressions.OperatorEnumConstructionModule.empty_all_globals!() - + let tree = ArrayNode{Float64,D,Vector}(; op=2, children=(x[1], x[2])) # * if D > 2 fma_idx = 1 - tree = ArrayNode{Float64,D,Vector}(; op=fma_idx, children=(tree, x[1], x[2])) # fma + tree = ArrayNode{Float64,D,Vector}(; + op=fma_idx, children=(tree, x[1], x[2]) + ) # fma end if D > 3 idx_max = 1 - tree = ArrayNode{Float64,D,Vector}(; op=idx_max, children=(tree, x[1], x[2], x[3])) # max + tree = ArrayNode{Float64,D,Vector}(; + op=idx_max, children=(tree, x[1], x[2], x[3]) + ) # max end @test Interfaces.test(NodeInterface, ArrayNode, [tree]) end @@ -102,49 +108,50 @@ end @testitem "ArrayNode basic operations" begin using DynamicExpressions using DynamicExpressions: OperatorEnum + using AllocCheck: @check_allocs const ArrayNode = DynamicExpressions.ArrayNode # Test with regular arrays (default) x1 = ArrayNode{Float64,2}(; feature=1) x2 = ArrayNode{Float64,2}(; feature=2) c = ArrayNode{Float64,2}(; val=3.5) - + # Test basic properties @test x1.degree == 0 @test x1.feature == 1 @test !x1.constant - + @test c.degree == 0 @test c.val == 3.5 @test c.constant - + # Test tree construction mul = ArrayNode{Float64,2}(; op=3, l=x2, r=c) @test mul.degree == 2 @test mul.op == 3 - + sin_expr = ArrayNode{Float64,2}(; op=1, l=mul) @test sin_expr.degree == 1 - + tree = ArrayNode{Float64,2}(; op=1, l=x1, r=sin_expr) @test tree.degree == 2 - + # Test copy tree_copy = copy(tree) @test tree == tree_copy @test tree !== tree_copy - + # Test hash @test hash(tree) == hash(tree_copy) - + # Test count_nodes @test count_nodes(tree) == 6 # tree, x1, sin_expr, mul, x2, c - + # Test string conversion operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos]) str = string_tree(tree, operators) @test str == "x1 + sin(x2 * 3.5)" - + # Test evaluation X = [1.0 2.0; 0.5 1.0] # 2 features, 2 samples result = eval_tree_array(tree, X, operators) @@ -152,6 +159,64 @@ end @test all(abs.(result[1] .- expected) .< 1e-10) end +@testitem "ArrayNode allocation tests" begin + using DynamicExpressions + using DynamicExpressions: OperatorEnum, eval_tree_array + using AllocCheck: @check_allocs + const ArrayNode = DynamicExpressions.ArrayNode + + # Create a tree with preallocated storage + allocator = DynamicExpressions.ArrayNodeModule.ArrayTree{Float64,2}(100) + x1 = ArrayNode{Float64,2}(; feature=1, allocator=allocator) + x2 = ArrayNode{Float64,2}(; feature=2, allocator=allocator) + c = ArrayNode{Float64,2}(; val=3.5, allocator=allocator) + + # Build tree using same allocator + mul = ArrayNode{Float64,2}(; op=3, l=x2, r=c, allocator=allocator) + sin_expr = ArrayNode{Float64,2}(; op=1, l=mul, allocator=allocator) + tree = ArrayNode{Float64,2}(; op=1, l=x1, r=sin_expr, allocator=allocator) + + operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos]) + X = [1.0 2.0; 0.5 1.0] # 2 features, 2 samples + + # Warm up + result = eval_tree_array(tree, X, operators) + + # Test that evaluation doesn't allocate + @check_allocs eval_tree_array(tree, X, operators) = eval_tree_array(tree, X, operators) + + # Test that property access doesn't allocate + @check_allocs get_degree(n) = n.degree + @check_allocs get_val(n) = n.val + @check_allocs get_feature(n) = n.feature + @check_allocs get_op(n) = n.op + + get_degree(tree) + get_val(c) + get_feature(x1) + get_op(tree) + + # Test that count_nodes doesn't allocate (after warm-up) + count_nodes(tree) + @check_allocs count_nodes(tree) = count_nodes(tree) + + # Test that tree traversal doesn't allocate + function traverse_tree(n::ArrayNode) + sum = n.degree + if n.degree > 0 + children = DynamicExpressions.NodeModule.get_children(n, Val(Int(n.degree))) + for child in children + sum += traverse_tree(child) + end + end + return sum + end + + # Warm up + traverse_tree(tree) + @check_allocs traverse_tree(tree) = traverse_tree(tree) +end + @testitem "ArrayNode with Expressions" begin using DynamicExpressions using DynamicExpressions: Expression @@ -161,17 +226,16 @@ end x1 = ArrayNode{Float64,2}(; feature=1) c = ArrayNode{Float64,2}(; val=2.0) tree = ArrayNode{Float64,2}(; op=1, l=x1, r=c) - + operators = OperatorEnum(; binary_operators=[+, -, *, /]) - + # Test Expression conversion expr = Expression(tree; operators=operators, variable_names=["x1", "x2"]) @test string(expr) == "x1 + 2.0" - + # Test evaluation through Expression X = [1.0 2.0 3.0] # 1 feature, 3 samples result = expr(X) expected = vec(X .+ 2.0) # Convert to vector to match result shape @test all(abs.(result .- expected) .< 1e-10) end - From 3411eddd4b36cb29f298c58ba1a3d59a33fbc925 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 29 Aug 2025 21:06:41 -0400 Subject: [PATCH 04/16] test: correctly include array node tests --- test/unittest.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/unittest.jl b/test/unittest.jl index 78e0dcd7..41a17bf6 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -129,6 +129,7 @@ include("test_parse.jl") include("test_parametric_expression.jl") include("test_operator_construction_edgecases.jl") include("test_node_interface.jl") +include("test_array_node.jl") include("test_expression_math.jl") include("test_structured_expression.jl") include("test_zygote_gradient_wrapper.jl") From c56006cbce00fef4c5baee3a05b19580acff5432 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 29 Aug 2025 21:59:43 -0400 Subject: [PATCH 05/16] feat: allow arbitrary base array in array node --- src/ArrayNode.jl | 114 ++++++++++++++++++++++++---------------- test/Project.toml | 1 + test/test_array_node.jl | 63 ++++++++++++++++++++++ 3 files changed, 134 insertions(+), 44 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index 9fb1962b..406dfa7e 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -20,8 +20,6 @@ import ..NodeModule: any, copy_node -import Base: copy, hash, ==, getproperty, setproperty!, eltype - export ArrayNode # Node data struct @@ -42,15 +40,31 @@ function NodeData{T,D}() where {T,D} end mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}} - nodes::S + const nodes::S root_idx::Int8 n_nodes::Int8 - free_list::Vector{Int8} + const free_list::Vector{Int8} free_count::Int8 - function ArrayTree{T,D}(n::Int) where {T,D} - # Create a StructVector with pre-allocated arrays - nodes = StructVector{NodeData{T,D}}(undef, n) + function ArrayTree{T,D}(n::Int; array_type::Type{<:AbstractVector}=Vector) where {T,D} + # Create backing arrays of the specified type + degree = array_type{UInt8}(undef, n) + constant = array_type{Bool}(undef, n) + val = array_type{T}(undef, n) + feature = array_type{UInt16}(undef, n) + op = array_type{UInt8}(undef, n) + children = array_type{NTuple{D,Int8}}(undef, n) + + # Create a StructVector from the backing arrays + nodes = StructVector{NodeData{T,D}}(( + degree=degree, + constant=constant, + val=val, + feature=feature, + op=op, + children=children, + )) + # Initialize all nodes to default values for i in 1:n nodes.degree[i] = UInt8(0) @@ -76,7 +90,7 @@ struct ArrayNode{T,D,S} <: AbstractExpressionNode{T,D} idx::Int8 end -function getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} +function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) nodes = getfield(tree, :nodes) @@ -116,7 +130,7 @@ function getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} end end -function setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} +function Base.setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) nodes = getfield(tree, :nodes) @@ -132,17 +146,13 @@ function setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} elseif k == :op @inbounds nodes.op[idx] = v elseif k == :l - if isa(v, ArrayNode) - children = nodes.children[idx] - @inbounds nodes.children[idx] = (getfield(v, :idx), children[2:end]...) - end + isa(v, ArrayNode) || error("Cannot set left child to non-ArrayNode") + children = nodes.children[idx] + @inbounds nodes.children[idx] = (getfield(v, :idx), children[2:end]...) elseif k == :r - if isa(v, ArrayNode) - children = nodes.children[idx] - @inbounds nodes.children[idx] = ( - children[1], getfield(v, :idx), children[3:end]... - ) - end + isa(v, ArrayNode) || error("Cannot set right child to non-ArrayNode") + children = nodes.children[idx] + @inbounds nodes.children[idx] = (children[1], getfield(v, :idx), children[3:end]...) else error("Cannot set field $k") end @@ -165,11 +175,16 @@ function free_node!(tree::ArrayTree, idx::Int8) end # Default constructors - now include array type parameters -ArrayNode{T,D}(n::Int) where {T,D} = ArrayNode{T,D}(Undefined; allocator=ArrayTree{T,D}(n)) -ArrayNode{T}(n::Int) where {T} = ArrayNode{T,2}(n) +function ArrayNode{T,D}(n::Int; array_type::Type{<:AbstractVector}=Vector) where {T,D} + return ArrayNode{T,D}(Undefined; allocator=ArrayTree{T,D}(n; array_type=array_type)) +end +function ArrayNode{T}(n::Int; array_type::Type{<:AbstractVector}=Vector) where {T} + return ArrayNode{T,2}(n; array_type=array_type) +end # Keyword constructors for partial type signatures ArrayNode{T,D}(; kwargs...) where {T,D} = ArrayNode{T,D}(Undefined; kwargs...) +ArrayNode{T,D,S}(; kwargs...) where {T,D,S} = ArrayNode{T,D}(Undefined; kwargs...) ArrayNode{T}(; kwargs...) where {T} = ArrayNode{T,2}(; kwargs...) # Constructor with keyword arguments - matches Node interface @@ -184,11 +199,11 @@ function ArrayNode{T,D}( allocator=nothing, ) where {T,D,T1} # Determine tree source - # Always create a new tree unless an allocator is explicitly provided tree = if !isnothing(allocator) && isa(allocator, ArrayTree) allocator else - # Just use a reasonable default size + # Default size of 64 nodes for small expressions + # This is wasteful if building incrementally, but matches Node semantics ArrayTree{T,D}(64) end @@ -213,6 +228,7 @@ function ArrayNode{T,D}( end if !isnothing(op) + # DEBUG: op=$op, l is nothing? $(isnothing(l)), r is nothing? $(isnothing(r)) _children = if !isnothing(l) && isnothing(r) (l,) elseif !isnothing(l) && !isnothing(r) @@ -222,6 +238,7 @@ function ArrayNode{T,D}( end if !isnothing(_children) + # DEBUG: Building node with children, length=length(_children) degree = length(_children) tree.nodes.degree[idx] = degree tree.nodes.op[idx] = op @@ -231,6 +248,7 @@ function ArrayNode{T,D}( i -> begin if i <= length(_children) child = _children[i] + # DEBUG: Processing child $i, isa ArrayNode? isa(child, ArrayNode) if isa(child, ArrayNode) child_tree = getfield(child, :tree) child_idx = getfield(child, :idx) @@ -239,7 +257,16 @@ function ArrayNode{T,D}( child_idx else # Different tree - copy - copy_subtree!(tree, child_tree, child_idx) + new_idx = copy_subtree!(tree, child_tree, child_idx) + # DEBUG + # println("Copied child from idx $child_idx to new idx $new_idx") + # println(" Original: constant=", child_tree.nodes.constant[child_idx], + # child_tree.nodes.constant[child_idx] ? ", val=" : ", feature=", + # child_tree.nodes.constant[child_idx] ? child_tree.nodes.val[child_idx] : child_tree.nodes.feature[child_idx]) + # println(" Copied: constant=", tree.nodes.constant[new_idx], + # tree.nodes.constant[new_idx] ? ", val=" : ", feature=", + # tree.nodes.constant[new_idx] ? tree.nodes.val[new_idx] : tree.nodes.feature[new_idx]) + new_idx end else Int8(-1) @@ -295,8 +322,8 @@ function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::Int8) end # Core interface implementations -eltype(::Type{<:ArrayNode{T}}) where {T} = T -eltype(::ArrayNode{T}) where {T} = T +Base.eltype(::Type{<:ArrayNode{T}}) where {T} = T +Base.eltype(::ArrayNode{T}) where {T} = T max_degree(::Type{<:ArrayNode}) = 2 max_degree(::Type{<:ArrayNode{T,D}}) where {T,D} = D @@ -307,7 +334,11 @@ preserve_sharing(::Type{<:ArrayNode}) = false constructorof(::Type{<:ArrayNode}) = ArrayNode with_type_parameters(::Type{<:ArrayNode}, ::Type{T}) where {T} = ArrayNode{T,2} with_max_degree(::Type{<:ArrayNode{T,D}}, ::Val{D2}) where {T,D,D2} = ArrayNode{T,D2} -default_allocator(::Type{ArrayNode{T,D}}) where {T,D} = ArrayTree{T,D}(32) +function default_allocator( + ::Type{ArrayNode{T,D}}; array_type::Type{<:AbstractVector}=Vector +) where {T,D} + return ArrayTree{T,D}(32; array_type=array_type) +end # get_children and set_children! function unsafe_get_children(n::ArrayNode{T,D,S}) where {T,D,S} @@ -361,7 +392,7 @@ function copy_node(n::ArrayNode{T,D,S}; break_sharing::Val{BS}=Val(false)) where idx = getfield(n, :idx) # Count nodes to determine tree size needed - node_count = count_subtree(tree, idx) + node_count = count_nodes(n) # Add some buffer space tree_size = max(32, node_count * 2) @@ -373,28 +404,23 @@ function copy_node(n::ArrayNode{T,D,S}; break_sharing::Val{BS}=Val(false)) where return ArrayNode{T,D,typeof(new_tree.nodes)}(new_tree, new_idx) end -copy(n::ArrayNode) = copy_node(n) +Base.copy(n::ArrayNode) = copy_node(n) -# count_nodes +# count_nodes - optimized version that checks if we're at root function count_nodes(n::ArrayNode) tree = getfield(n, :tree) - return count_subtree(tree, getfield(n, :idx)) -end - -function count_subtree(tree::ArrayTree, idx::Int8) - count = 1 - degree = @inbounds tree.nodes.degree[idx] - for i in 1:degree - child_idx = @inbounds tree.nodes.children[idx][i] - if child_idx >= 0 - count += count_subtree(tree, child_idx) - end + idx = getfield(n, :idx) + # Optimization: if this is the root of the tree, just return total nodes + if tree.root_idx == idx + return Int(tree.n_nodes) + else + # Fall back to tree_mapreduce for subtrees + return tree_mapreduce(_ -> 1, +, n, Int) end - return count end # Equality and hash -function ==(a::ArrayNode, b::ArrayNode) +function Base.:(==)(a::ArrayNode, b::ArrayNode) a.degree != b.degree && return false if a.degree == 0 @@ -417,7 +443,7 @@ function ==(a::ArrayNode, b::ArrayNode) end end -function hash(n::ArrayNode, h::UInt=zero(UInt)) +function Base.hash(n::ArrayNode, h::UInt=zero(UInt)) if n.degree == 0 if n.constant return hash((0, n.val), h) diff --git a/test/Project.toml b/test/Project.toml index 6ecd3917..4aacc196 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" +FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/test/test_array_node.jl b/test/test_array_node.jl index efa4e0a1..a9d1c870 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -239,3 +239,66 @@ end expected = vec(X .+ 2.0) # Convert to vector to match result shape @test all(abs.(result .- expected) .< 1e-10) end + +@testitem "ArrayNode with FixedSizeArrays backing storage" begin + using DynamicExpressions + using DynamicExpressions: OperatorEnum + using FixedSizeArrays + using AllocCheck: @check_allocs + const ArrayNode = DynamicExpressions.ArrayNode + + # Create a FixedSizeVector type for our backing storage + # We'll use size 100 for this test + const N = 100 + + # Create an ArrayTree with FixedSizeVector backing + allocator = DynamicExpressions.ArrayNodeModule.ArrayTree{Float64,2}( + N; array_type=FixedSizeVector + ) + + # Test that the backing arrays are indeed FixedSizeArrays + @test allocator.nodes.degree isa FixedSizeArray{UInt8} + @test allocator.nodes.val isa FixedSizeArray{Float64} + @test allocator.nodes.feature isa FixedSizeArray{UInt16} + + # Create nodes using the FixedSizeVector-backed allocator + x1 = ArrayNode{Float64,2}(; feature=1, allocator=allocator) + x2 = ArrayNode{Float64,2}(; feature=2, allocator=allocator) + c = ArrayNode{Float64,2}(; val=3.5, allocator=allocator) + + # Build a tree + mul = ArrayNode{Float64,2}(; op=3, l=x2, r=c, allocator=allocator) + sin_expr = ArrayNode{Float64,2}(; op=1, l=mul, allocator=allocator) + tree = ArrayNode{Float64,2}(; op=1, l=x1, r=sin_expr, allocator=allocator) + + # Test basic operations + @test tree.degree == 2 + @test x1.feature == 1 + @test c.val == 3.5 + + # Test evaluation + operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos]) + X = [1.0 2.0; 0.5 1.0] # 2 features, 2 samples + result, complete = eval_tree_array(tree, X, operators) + expected = X[1, :] .+ sin.(X[2, :] .* 3.5) + @test all(abs.(result .- expected) .< 1e-10) + + # Test that operations are still allocation-free + @check_allocs get_degree(n) = n.degree + @check_allocs get_val(n) = n.val + @check_allocs get_feature(n) = n.feature + + get_degree(tree) + get_val(c) + get_feature(x1) + + # Test count_nodes + @test count_nodes(tree) == 6 + + # Test creating nodes is allocation-free with preallocated FixedSizeVector storage + @check_allocs create_node(alloc, f) = ArrayNode{Float64,2}(; feature=f, allocator=alloc) + new_node = create_node(allocator, 5) + @test new_node.feature == 5 + + println("✅ ArrayNode works with FixedSizeArrays backing storage!") +end From e8cb83be8c89f54acb455b46855e8c7fc3660e1c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 29 Aug 2025 22:16:37 -0400 Subject: [PATCH 06/16] test: parity with nodes --- test/test_array_node.jl | 177 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) diff --git a/test/test_array_node.jl b/test/test_array_node.jl index a9d1c870..99a48e83 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -302,3 +302,180 @@ end println("✅ ArrayNode works with FixedSizeArrays backing storage!") end + +@testitem "ArrayNode vs Node comparison with random trees" begin + using DynamicExpressions + using DynamicExpressions: Node, OperatorEnum + using Random: MersenneTwister + include("tree_gen_utils.jl") + + const ArrayNode = DynamicExpressions.ArrayNode + + # Test with different operator configurations + operators_configs = [ + OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos]), + OperatorEnum(; binary_operators=[+, *], unary_operators=[-, abs]), + OperatorEnum(; + binary_operators=[+, -, *, /, ^], unary_operators=[sin, cos, exp, log] + ), + ] + + for operators in operators_configs + rng = MersenneTwister(42) + nfeatures = 3 + + for tree_size in [5, 10, 20] + for _ in 1:5 # Test multiple random trees of each size + # Generate a random Node tree + node_tree = gen_random_tree_fixed_size( + tree_size, operators, nfeatures, Float64, Node, rng + ) + + # Convert to ArrayNode + # First, create an allocator with enough space + allocator = DynamicExpressions.ArrayNodeModule.ArrayTree{Float64,2}( + tree_size * 2 + ) + + # Function to convert Node to ArrayNode + function node_to_array_node(n::Node, alloc) + if n.degree == 0 + if n.constant + return ArrayNode{Float64,2}(; val=n.val, allocator=alloc) + else + return ArrayNode{Float64,2}(; + feature=n.feature, allocator=alloc + ) + end + elseif n.degree == 1 + child = node_to_array_node(n.l, alloc) + return ArrayNode{Float64,2}(; op=n.op, l=child, allocator=alloc) + else # degree == 2 + left = node_to_array_node(n.l, alloc) + right = node_to_array_node(n.r, alloc) + return ArrayNode{Float64,2}(; + op=n.op, l=left, r=right, allocator=alloc + ) + end + end + + array_tree = node_to_array_node(node_tree, allocator) + + # Test 1: Count nodes + @test DynamicExpressions.count_nodes(node_tree) == + DynamicExpressions.count_nodes(array_tree) + + # Test 2: String representation + node_str = DynamicExpressions.string_tree(node_tree, operators) + array_str = DynamicExpressions.string_tree(array_tree, operators) + @test node_str == array_str + + # Test 3: Evaluation on random data + X = randn(rng, nfeatures, 10) + node_result, node_ok = DynamicExpressions.eval_tree_array( + node_tree, X, operators + ) + array_result, array_ok = DynamicExpressions.eval_tree_array( + array_tree, X, operators + ) + + @test node_ok == array_ok + if node_ok && array_ok + # Check that results match (accounting for floating point errors) + @test all(isnan.(node_result) .== isnan.(array_result)) + valid_idx = .!isnan.(node_result) .& .!isnan.(array_result) + if any(valid_idx) + @test all( + abs.(node_result[valid_idx] .- array_result[valid_idx]) .< 1e-10 + ) + end + end + + # Test 4: Hash consistency + # Two equivalent trees should have the same hash + array_tree2 = node_to_array_node(node_tree, allocator) + @test hash(array_tree) == hash(array_tree2) + + # Test 5: Copy operation + array_copy = copy(array_tree) + @test array_copy == array_tree + @test array_copy !== array_tree + @test DynamicExpressions.count_nodes(array_copy) == + DynamicExpressions.count_nodes(array_tree) + end + end + end + + println("✅ ArrayNode matches Node behavior on random trees!") +end + +@testitem "ArrayNode tree_mapreduce operations" begin + using DynamicExpressions + using DynamicExpressions: Node, OperatorEnum, tree_mapreduce + using Random: MersenneTwister + include("tree_gen_utils.jl") + + const ArrayNode = DynamicExpressions.ArrayNode + + operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[sin, -]) + rng = MersenneTwister(123) + nfeatures = 2 + + for tree_size in [5, 10, 15] + # Generate random Node tree + node_tree = gen_random_tree_fixed_size( + tree_size, operators, nfeatures, Float64, Node, rng + ) + + # Convert to ArrayNode + allocator = DynamicExpressions.ArrayNodeModule.ArrayTree{Float64,2}(tree_size * 2) + + function node_to_array_node(n::Node, alloc) + if n.degree == 0 + if n.constant + return ArrayNode{Float64,2}(; val=n.val, allocator=alloc) + else + return ArrayNode{Float64,2}(; feature=n.feature, allocator=alloc) + end + elseif n.degree == 1 + child = node_to_array_node(n.l, alloc) + return ArrayNode{Float64,2}(; op=n.op, l=child, allocator=alloc) + else + left = node_to_array_node(n.l, alloc) + right = node_to_array_node(n.r, alloc) + return ArrayNode{Float64,2}(; op=n.op, l=left, r=right, allocator=alloc) + end + end + + array_tree = node_to_array_node(node_tree, allocator) + + # Test various tree_mapreduce operations + + # 1. Count constants + count_constants = t -> t.constant ? 1 : 0 + node_const_count = tree_mapreduce(count_constants, +, node_tree, Int) + array_const_count = tree_mapreduce(count_constants, +, array_tree, Int) + @test node_const_count == array_const_count + + # 2. Count features + count_features = t -> (!t.constant && t.degree == 0) ? 1 : 0 + node_feat_count = tree_mapreduce(count_features, +, node_tree, Int) + array_feat_count = tree_mapreduce(count_features, +, array_tree, Int) + @test node_feat_count == array_feat_count + + # 3. Max depth + depth_fn = t -> 1 + max_fn = (a, b...) -> maximum((a, b...)) + node_depth = tree_mapreduce(depth_fn, max_fn, node_tree, Int) + array_depth = tree_mapreduce(depth_fn, max_fn, array_tree, Int) + @test node_depth == array_depth + + # 4. Check if any node has specific property + has_sin = t -> (t.degree > 0 && t.op == 1) # Assuming sin is first unary op + node_has_sin = DynamicExpressions.any(has_sin, node_tree) + array_has_sin = DynamicExpressions.any(has_sin, array_tree) + @test node_has_sin == array_has_sin + end + + println("✅ ArrayNode tree_mapreduce operations match Node!") +end From 6044bc3bd0e8857682c29c35d135bb1d534d2ec6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 15:32:13 +0100 Subject: [PATCH 07/16] refactor: use UInt16 for indices in array node --- src/ArrayNode.jl | 98 +++++++++++++++++++++--------------------------- 1 file changed, 42 insertions(+), 56 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index 406dfa7e..897159fe 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -29,22 +29,22 @@ struct NodeData{T,D} val::T feature::UInt16 op::UInt8 - children::NTuple{D,Int8} + children::NTuple{D,UInt16} end # Constructor for empty node function NodeData{T,D}() where {T,D} return NodeData{T,D}( - UInt8(0), true, zero(T), UInt16(0), UInt8(0), ntuple(_ -> Int8(-1), Val(D)) + UInt8(0), true, zero(T), UInt16(0), UInt8(0), ntuple(_ -> UInt16(0), Val(D)) ) end mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}} const nodes::S - root_idx::Int8 - n_nodes::Int8 - const free_list::Vector{Int8} - free_count::Int8 + root_idx::UInt16 + n_nodes::UInt16 + const free_list::Vector{UInt16} + free_count::UInt16 function ArrayTree{T,D}(n::Int; array_type::Type{<:AbstractVector}=Vector) where {T,D} # Create backing arrays of the specified type @@ -53,7 +53,7 @@ mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}} val = array_type{T}(undef, n) feature = array_type{UInt16}(undef, n) op = array_type{UInt8}(undef, n) - children = array_type{NTuple{D,Int8}}(undef, n) + children = array_type{NTuple{D,UInt16}}(undef, n) # Create a StructVector from the backing arrays nodes = StructVector{NodeData{T,D}}(( @@ -65,21 +65,11 @@ mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}} children=children, )) - # Initialize all nodes to default values - for i in 1:n - nodes.degree[i] = UInt8(0) - nodes.constant[i] = true - nodes.val[i] = zero(T) - nodes.feature[i] = UInt16(0) - nodes.op[i] = UInt8(0) - nodes.children[i] = ntuple(_ -> Int8(-1), Val(D)) - end - S = typeof(nodes) - tree = new{T,D,S}(nodes, Int8(0), Int8(0), Vector{Int8}(undef, n), Int8(n)) + tree = new{T,D,S}(nodes, UInt16(0), UInt16(0), Vector{UInt16}(undef, n), UInt16(n)) # Initialize free list for i in 1:n - tree.free_list[i] = Int8(i) + tree.free_list[i] = UInt16(i) end return tree end @@ -87,7 +77,7 @@ end struct ArrayNode{T,D,S} <: AbstractExpressionNode{T,D} tree::ArrayTree{T,D,S} - idx::Int8 + idx::UInt16 end function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} @@ -113,7 +103,7 @@ function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} # Return tuple of child ArrayNodes wrapped in Nullable return ntuple(Val(D)) do i child_idx = @inbounds nodes.children[idx][i] - if child_idx < 0 + if child_idx == 0 Nullable(true, n) # Poison node else Nullable(false, ArrayNode{T,D,S}(tree, child_idx)) @@ -121,10 +111,10 @@ function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} end elseif k == :l # Left child for compatibility child_idx = @inbounds nodes.children[idx][1] - return child_idx < 0 ? error("No left child") : ArrayNode{T,D,S}(tree, child_idx) + return child_idx == 0 ? error("No left child") : ArrayNode{T,D,S}(tree, child_idx) elseif k == :r # Right child for compatibility child_idx = @inbounds nodes.children[idx][2] - return child_idx < 0 ? error("No right child") : ArrayNode{T,D,S}(tree, child_idx) + return child_idx == 0 ? error("No right child") : ArrayNode{T,D,S}(tree, child_idx) else error("Unknown field $k") end @@ -168,7 +158,7 @@ function allocate_node!(tree::ArrayTree) return idx end -function free_node!(tree::ArrayTree, idx::Int8) +function free_node!(tree::ArrayTree, idx::UInt16) tree.free_count += 1 tree.free_list[tree.free_count] = idx return tree.n_nodes -= 1 @@ -228,7 +218,6 @@ function ArrayNode{T,D}( end if !isnothing(op) - # DEBUG: op=$op, l is nothing? $(isnothing(l)), r is nothing? $(isnothing(r)) _children = if !isnothing(l) && isnothing(r) (l,) elseif !isnothing(l) && !isnothing(r) @@ -238,7 +227,6 @@ function ArrayNode{T,D}( end if !isnothing(_children) - # DEBUG: Building node with children, length=length(_children) degree = length(_children) tree.nodes.degree[idx] = degree tree.nodes.op[idx] = op @@ -248,7 +236,6 @@ function ArrayNode{T,D}( i -> begin if i <= length(_children) child = _children[i] - # DEBUG: Processing child $i, isa ArrayNode? isa(child, ArrayNode) if isa(child, ArrayNode) child_tree = getfield(child, :tree) child_idx = getfield(child, :idx) @@ -258,21 +245,13 @@ function ArrayNode{T,D}( else # Different tree - copy new_idx = copy_subtree!(tree, child_tree, child_idx) - # DEBUG - # println("Copied child from idx $child_idx to new idx $new_idx") - # println(" Original: constant=", child_tree.nodes.constant[child_idx], - # child_tree.nodes.constant[child_idx] ? ", val=" : ", feature=", - # child_tree.nodes.constant[child_idx] ? child_tree.nodes.val[child_idx] : child_tree.nodes.feature[child_idx]) - # println(" Copied: constant=", tree.nodes.constant[new_idx], - # tree.nodes.constant[new_idx] ? ", val=" : ", feature=", - # tree.nodes.constant[new_idx] ? tree.nodes.val[new_idx] : tree.nodes.feature[new_idx]) new_idx end else - Int8(-1) + UInt16(0) end else - Int8(-1) + UInt16(0) end end, Val(D), @@ -290,7 +269,7 @@ function ArrayNode{T,D}( return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end -function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::Int8) where {T,D} +function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::UInt16) where {T,D} dst_idx = allocate_node!(dst) @inbounds begin @@ -306,13 +285,13 @@ function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::Int8) i -> begin if i <= degree child_idx = @inbounds src.nodes.children[src_idx][i] - if child_idx >= 0 + if child_idx > 0 copy_subtree!(dst, src, child_idx) else - Int8(-1) + UInt16(0) end else - Int8(-1) + UInt16(0) end end, Val(D) ) @@ -347,10 +326,10 @@ function unsafe_get_children(n::ArrayNode{T,D,S}) where {T,D,S} return ntuple( i -> begin child_idx = @inbounds tree.nodes.children[idx][i] - if child_idx < 0 + if child_idx == 0 Nullable(true, n) else - Nullable(false, ArrayNode{T,D,typeof(tree.nodes)}(tree, child_idx)) + Nullable(false, ArrayNode{T,D,S}(tree, child_idx)) end end, Val(D), @@ -362,7 +341,7 @@ function get_children(n::ArrayNode{T,D,S}, ::Val{d}) where {T,D,S,d} idx = getfield(n, :idx) return ntuple(i -> begin child_idx = @inbounds tree.nodes.children[idx][i] - ArrayNode{T,D,typeof(tree.nodes)}(tree, child_idx) + ArrayNode{T,D,S}(tree, child_idx) end, Val(Int(d))) end @@ -377,10 +356,10 @@ function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S} if isa(child, ArrayNode) getfield(child, :idx) else - Int8(-1) + UInt16(0) end else - Int8(-1) + UInt16(0) end end, Val(D)) return tree.nodes.children[idx] = child_indices @@ -396,12 +375,19 @@ function copy_node(n::ArrayNode{T,D,S}; break_sharing::Val{BS}=Val(false)) where # Add some buffer space tree_size = max(32, node_count * 2) - # Create new tree for the copy - new_tree = ArrayTree{T,D}(tree_size) + # Determine the array type from the existing tree's nodes + # Default to Vector since that's the most common case + # For other array types, we'd need more sophisticated type extraction + new_tree = if tree.nodes.degree isa Vector + ArrayTree{T,D}(tree_size; array_type=Vector) + else + # For other array types like FixedSizeVector, we just use default + ArrayTree{T,D}(tree_size) + end new_idx = copy_subtree!(new_tree, tree, idx) new_tree.root_idx = new_idx - return ArrayNode{T,D,typeof(new_tree.nodes)}(new_tree, new_idx) + return ArrayNode{T,D,S}(new_tree, new_idx) end Base.copy(n::ArrayNode) = copy_node(n) @@ -485,7 +471,7 @@ function set_node!(dst::ArrayNode, src::ArrayNode) i -> begin if i <= src.degree child_idx = @inbounds src_tree.nodes.children[src_idx][i] - if child_idx >= 0 + if child_idx > 0 if dst_tree === src_tree # Same tree child_idx @@ -494,10 +480,10 @@ function set_node!(dst::ArrayNode, src::ArrayNode) copy_subtree!(dst_tree, src_tree, child_idx) end else - Int8(-1) + UInt16(0) end else - Int8(-1) + UInt16(0) end end, Val(D), @@ -522,7 +508,7 @@ function tree_mapreduce( return mapreduce_impl(f, op, tree, getfield(n, :idx)) end -function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::Int8) where {F,G,T,D,S} +function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,G,T,D,S} degree = @inbounds tree.nodes.degree[idx] node = ArrayNode{T,D,S}(tree, idx) result = f(node) @@ -531,7 +517,7 @@ function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::Int8) where {F child_results = ntuple( i -> begin child_idx = @inbounds tree.nodes.children[idx][i] - if child_idx >= 0 + if child_idx > 0 mapreduce_impl(f, op, tree, child_idx) else nothing @@ -554,14 +540,14 @@ function any(f::F, n::ArrayNode) where {F<:Function} return any_impl(f, tree, getfield(n, :idx)) end -function any_impl(f::F, tree::ArrayTree{T,D,S}, idx::Int8) where {F,T,D,S} +function any_impl(f::F, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,T,D,S} node = ArrayNode{T,D,S}(tree, idx) f(node) && return true degree = @inbounds tree.nodes.degree[idx] for i in 1:degree child_idx = @inbounds tree.nodes.children[idx][i] - if child_idx >= 0 && any_impl(f, tree, child_idx) + if child_idx > 0 && any_impl(f, tree, child_idx) return true end end From cba7ec0e111e8def392edb64a3532b9d181fd250 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 15:43:25 +0100 Subject: [PATCH 08/16] test: fix check alloc overwrite --- test/test_array_node.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_array_node.jl b/test/test_array_node.jl index 99a48e83..720b66ea 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -183,7 +183,8 @@ end result = eval_tree_array(tree, X, operators) # Test that evaluation doesn't allocate - @check_allocs eval_tree_array(tree, X, operators) = eval_tree_array(tree, X, operators) + @check_allocs check_eval(t, x, ops) = eval_tree_array(t, x, ops) + check_eval(tree, X, operators) # Test that property access doesn't allocate @check_allocs get_degree(n) = n.degree From 8b8fa485060cc9f43e10f2bca3a81c2d41569bf6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 16:05:47 +0100 Subject: [PATCH 09/16] test: fix count nodes test --- test/test_array_node.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_array_node.jl b/test/test_array_node.jl index 720b66ea..49a4cfbb 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -198,8 +198,8 @@ end get_op(tree) # Test that count_nodes doesn't allocate (after warm-up) - count_nodes(tree) - @check_allocs count_nodes(tree) = count_nodes(tree) + @check_allocs count_nodes_test(t) = count_nodes(t) + count_nodes_test(tree) # Test that tree traversal doesn't allocate function traverse_tree(n::ArrayNode) From 1ed8721a587155a4bb175405cbef7d64177964b0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 16:13:10 +0100 Subject: [PATCH 10/16] refactor: small cleanup of array nodes --- src/ArrayNode.jl | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index 897159fe..b383162d 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -80,7 +80,7 @@ struct ArrayNode{T,D,S} <: AbstractExpressionNode{T,D} idx::UInt16 end -function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} +@inline function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) nodes = getfield(tree, :nodes) @@ -109,18 +109,20 @@ function Base.getproperty(n::ArrayNode{T,D,S}, k::Symbol) where {T,D,S} Nullable(false, ArrayNode{T,D,S}(tree, child_idx)) end end - elseif k == :l # Left child for compatibility + elseif k == :l child_idx = @inbounds nodes.children[idx][1] - return child_idx == 0 ? error("No left child") : ArrayNode{T,D,S}(tree, child_idx) - elseif k == :r # Right child for compatibility + child_idx == 0 && error("No left child") + return ArrayNode{T,D,S}(tree, child_idx) + elseif k == :r child_idx = @inbounds nodes.children[idx][2] - return child_idx == 0 ? error("No right child") : ArrayNode{T,D,S}(tree, child_idx) + child_idx == 0 && error("No right child") + return ArrayNode{T,D,S}(tree, child_idx) else error("Unknown field $k") end end -function Base.setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} +@inline function Base.setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) nodes = getfield(tree, :nodes) @@ -136,11 +138,11 @@ function Base.setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} elseif k == :op @inbounds nodes.op[idx] = v elseif k == :l - isa(v, ArrayNode) || error("Cannot set left child to non-ArrayNode") + !isa(v, ArrayNode) && error("Cannot set left child to non-ArrayNode") children = nodes.children[idx] @inbounds nodes.children[idx] = (getfield(v, :idx), children[2:end]...) elseif k == :r - isa(v, ArrayNode) || error("Cannot set right child to non-ArrayNode") + !isa(v, ArrayNode) && error("Cannot set right child to non-ArrayNode") children = nodes.children[idx] @inbounds nodes.children[idx] = (children[1], getfield(v, :idx), children[3:end]...) else @@ -150,7 +152,7 @@ function Base.setproperty!(n::ArrayNode{T,D,S}, k::Symbol, v) where {T,D,S} end # Allocation management -function allocate_node!(tree::ArrayTree) +@inline function allocate_node!(tree::ArrayTree) tree.free_count == 0 && error("ArrayTree full") idx = tree.free_list[tree.free_count] tree.free_count -= 1 @@ -158,7 +160,7 @@ function allocate_node!(tree::ArrayTree) return idx end -function free_node!(tree::ArrayTree, idx::UInt16) +@inline function free_node!(tree::ArrayTree, idx::UInt16) tree.free_count += 1 tree.free_list[tree.free_count] = idx return tree.n_nodes -= 1 @@ -300,16 +302,6 @@ function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::UInt16 return dst_idx end -# Core interface implementations -Base.eltype(::Type{<:ArrayNode{T}}) where {T} = T -Base.eltype(::ArrayNode{T}) where {T} = T - -max_degree(::Type{<:ArrayNode}) = 2 -max_degree(::Type{<:ArrayNode{T,D}}) where {T,D} = D -max_degree(n::ArrayNode) = max_degree(typeof(n)) - -preserve_sharing(::Type{<:ArrayNode}) = false - constructorof(::Type{<:ArrayNode}) = ArrayNode with_type_parameters(::Type{<:ArrayNode}, ::Type{T}) where {T} = ArrayNode{T,2} with_max_degree(::Type{<:ArrayNode{T,D}}, ::Val{D2}) where {T,D,D2} = ArrayNode{T,D2} @@ -508,10 +500,10 @@ function tree_mapreduce( return mapreduce_impl(f, op, tree, getfield(n, :idx)) end -function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,G,T,D,S} +@inline function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,G,T,D,S} degree = @inbounds tree.nodes.degree[idx] node = ArrayNode{T,D,S}(tree, idx) - result = f(node) + result = @inline f(node) if degree > 0 child_results = ntuple( From 8d5b3a76c77e8e5750824f739838699170249fa6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 16:15:31 +0100 Subject: [PATCH 11/16] test: skip alloc check for irrelevant method --- test/test_array_node.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/test/test_array_node.jl b/test/test_array_node.jl index 49a4cfbb..18c6eaa1 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -179,12 +179,6 @@ end operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos]) X = [1.0 2.0; 0.5 1.0] # 2 features, 2 samples - # Warm up - result = eval_tree_array(tree, X, operators) - - # Test that evaluation doesn't allocate - @check_allocs check_eval(t, x, ops) = eval_tree_array(t, x, ops) - check_eval(tree, X, operators) # Test that property access doesn't allocate @check_allocs get_degree(n) = n.degree From 38ef14a8744ecccdc80e78878c326997a678744f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 16:29:40 +0100 Subject: [PATCH 12/16] refactor: remove redundant parts of interface --- src/ArrayNode.jl | 165 ----------------------------------------------- 1 file changed, 165 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index b383162d..dddb7686 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -337,8 +337,6 @@ function get_children(n::ArrayNode{T,D,S}, ::Val{d}) where {T,D,S,d} end, Val(Int(d))) end -get_children(n::ArrayNode, d::Integer) = get_children(n, Val(d)) - function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) @@ -384,167 +382,4 @@ end Base.copy(n::ArrayNode) = copy_node(n) -# count_nodes - optimized version that checks if we're at root -function count_nodes(n::ArrayNode) - tree = getfield(n, :tree) - idx = getfield(n, :idx) - # Optimization: if this is the root of the tree, just return total nodes - if tree.root_idx == idx - return Int(tree.n_nodes) - else - # Fall back to tree_mapreduce for subtrees - return tree_mapreduce(_ -> 1, +, n, Int) - end -end - -# Equality and hash -function Base.:(==)(a::ArrayNode, b::ArrayNode) - a.degree != b.degree && return false - - if a.degree == 0 - a.constant != b.constant && return false - if a.constant - return a.val == b.val - else - return a.feature == b.feature - end - else - a.op != b.op && return false - - # Compare children recursively - for i in 1:(a.degree) - ca = get_children(a, Val(Int(a.degree)))[i] - cb = get_children(b, Val(Int(b.degree)))[i] - ca != cb && return false - end - return true - end -end - -function Base.hash(n::ArrayNode, h::UInt=zero(UInt)) - if n.degree == 0 - if n.constant - return hash((0, n.val), h) - else - return hash((1, n.feature), h) - end - else - children_hashes = ntuple( - i -> begin - child = get_children(n, Val(Int(n.degree)))[i] - hash(child, h) - end, Val(Int(n.degree)) - ) - return hash((n.degree + 1, n.op, children_hashes), h) - end -end - -# set_node! implementation -function set_node!(dst::ArrayNode, src::ArrayNode) - dst_tree = getfield(dst, :tree) - src_tree = getfield(src, :tree) - dst_idx = getfield(dst, :idx) - src_idx = getfield(src, :idx) - - dst.degree = src.degree - - if src.degree == 0 - dst.constant = src.constant - if src.constant - dst.val = src.val - else - dst.feature = src.feature - end - else - dst.op = src.op - - D = max_degree(typeof(dst)) - child_indices = ntuple( - i -> begin - if i <= src.degree - child_idx = @inbounds src_tree.nodes.children[src_idx][i] - if child_idx > 0 - if dst_tree === src_tree - # Same tree - child_idx - else - # Different tree - need to copy - copy_subtree!(dst_tree, src_tree, child_idx) - end - else - UInt16(0) - end - else - UInt16(0) - end - end, - Val(D), - ) - dst_tree.nodes.children[dst_idx] = child_indices - end - - return nothing -end - -# tree_mapreduce and any -function tree_mapreduce( - f::F, - op::G, - n::ArrayNode, - (::Type{RT})=Any; - f_on_shared=nothing, - break_sharing=Val(false), - kwargs..., -) where {F<:Function,G<:Function,RT} - tree = getfield(n, :tree) - return mapreduce_impl(f, op, tree, getfield(n, :idx)) -end - -@inline function mapreduce_impl(f::F, op::G, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,G,T,D,S} - degree = @inbounds tree.nodes.degree[idx] - node = ArrayNode{T,D,S}(tree, idx) - result = @inline f(node) - - if degree > 0 - child_results = ntuple( - i -> begin - child_idx = @inbounds tree.nodes.children[idx][i] - if child_idx > 0 - mapreduce_impl(f, op, tree, child_idx) - else - nothing - end - end, Val(Int(degree)) - ) - - # Filter out nothings and apply op - valid_results = filter(x -> !isnothing(x), child_results) - if !isempty(valid_results) - return op(result, valid_results...) - end - end - - return result -end - -function any(f::F, n::ArrayNode) where {F<:Function} - tree = getfield(n, :tree) - return any_impl(f, tree, getfield(n, :idx)) -end - -function any_impl(f::F, tree::ArrayTree{T,D,S}, idx::UInt16) where {F,T,D,S} - node = ArrayNode{T,D,S}(tree, idx) - f(node) && return true - - degree = @inbounds tree.nodes.degree[idx] - for i in 1:degree - child_idx = @inbounds tree.nodes.children[idx][i] - if child_idx > 0 && any_impl(f, tree, child_idx) - return true - end - end - - return false -end - end # module From 5bfb9d72c7f8f7d3c5a6c209632e286d32f57c2d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 16:45:52 +0100 Subject: [PATCH 13/16] fix: `set_children!` for ArrayNode --- src/ArrayNode.jl | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index dddb7686..e106c527 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -246,8 +246,7 @@ function ArrayNode{T,D}( child_idx else # Different tree - copy - new_idx = copy_subtree!(tree, child_tree, child_idx) - new_idx + copy_subtree!(tree, child_tree, child_idx) end else UInt16(0) @@ -340,19 +339,44 @@ end function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) - child_indices = ntuple(i -> begin + child_indices = ntuple(Val(D)) do i if i <= length(cs) child = cs[i] - if isa(child, ArrayNode) - getfield(child, :idx) + if isa(child, Nullable) + # Handle Nullable wrapped children + if child.null + UInt16(0) + else + child_node = child.x + child_tree = getfield(child_node, :tree) + child_idx = getfield(child_node, :idx) + if child_tree === tree + # Same tree - just use the index + child_idx + else + # Different tree - need to copy the subtree + copy_subtree!(tree, child_tree, child_idx) + end + end + elseif isa(child, ArrayNode) + child_tree = getfield(child, :tree) + child_idx = getfield(child, :idx) + if child_tree === tree + # Same tree - just use the index + child_idx + else + # Different tree - need to copy the subtree + copy_subtree!(tree, child_tree, child_idx) + end else UInt16(0) end else UInt16(0) end - end, Val(D)) - return tree.nodes.children[idx] = child_indices + end + tree.nodes.children[idx] = child_indices + return nothing end # Copy From 1e70235d435bdc699d02a986a120dc3a29e20d22 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 17:29:38 +0100 Subject: [PATCH 14/16] refactor: reduce some redundant allocations --- src/ArrayNode.jl | 181 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 135 insertions(+), 46 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index e106c527..97de16c8 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -47,28 +47,34 @@ mutable struct ArrayTree{T,D,S<:StructVector{NodeData{T,D}}} free_count::UInt16 function ArrayTree{T,D}(n::Int; array_type::Type{<:AbstractVector}=Vector) where {T,D} - # Create backing arrays of the specified type - degree = array_type{UInt8}(undef, n) - constant = array_type{Bool}(undef, n) - val = array_type{T}(undef, n) - feature = array_type{UInt16}(undef, n) - op = array_type{UInt8}(undef, n) - children = array_type{NTuple{D,UInt16}}(undef, n) - - # Create a StructVector from the backing arrays - nodes = StructVector{NodeData{T,D}}(( - degree=degree, - constant=constant, - val=val, - feature=feature, - op=op, - children=children, - )) + # Create uninitialized StructVector directly + # For custom array types, we'd need to pass them to StructVector somehow + # For now, just use the default + nodes = if array_type === Vector + StructVector{NodeData{T,D}}(undef, n) + else + # For other array types, create backing arrays manually + degree = array_type{UInt8}(undef, n) + constant = array_type{Bool}(undef, n) + val = array_type{T}(undef, n) + feature = array_type{UInt16}(undef, n) + op = array_type{UInt8}(undef, n) + children = array_type{NTuple{D,UInt16}}(undef, n) + StructVector{NodeData{T,D}}(( + degree=degree, + constant=constant, + val=val, + feature=feature, + op=op, + children=children, + )) + end S = typeof(nodes) - tree = new{T,D,S}(nodes, UInt16(0), UInt16(0), Vector{UInt16}(undef, n), UInt16(n)) - # Initialize free list - for i in 1:n + free_list = Vector{UInt16}(undef, n) + tree = new{T,D,S}(nodes, UInt16(0), UInt16(0), free_list, UInt16(n)) + # Initialize free list in-place + @inbounds @simd for i in 1:n tree.free_list[i] = UInt16(i) end return tree @@ -270,7 +276,9 @@ function ArrayNode{T,D}( return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end -function copy_subtree!(dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::UInt16) where {T,D} +function copy_subtree!( + dst::ArrayTree{T,D}, src::ArrayTree{T,D}, src_idx::UInt16 +) where {T,D} dst_idx = allocate_node!(dst) @inbounds begin @@ -314,17 +322,14 @@ end function unsafe_get_children(n::ArrayNode{T,D,S}) where {T,D,S} tree = getfield(n, :tree) idx = getfield(n, :idx) - return ntuple( - i -> begin - child_idx = @inbounds tree.nodes.children[idx][i] - if child_idx == 0 - Nullable(true, n) - else - Nullable(false, ArrayNode{T,D,S}(tree, child_idx)) - end - end, - Val(D), - ) + return ntuple(i -> begin + child_idx = @inbounds tree.nodes.children[idx][i] + if child_idx == 0 + Nullable(true, n) + else + Nullable(false, ArrayNode{T,D,S}(tree, child_idx)) + end + end, Val(D)) end function get_children(n::ArrayNode{T,D,S}, ::Val{d}) where {T,D,S,d} @@ -379,31 +384,115 @@ function set_children!(n::ArrayNode{T,D,S}, cs::Tuple) where {T,D,S} return nothing end +# Helper to mark nodes as reachable from a given root +function mark_reachable!( + reachable::Vector{Bool}, tree::ArrayTree{T,D}, idx::UInt16 +) where {T,D} + if idx == 0 || reachable[idx] + return nothing + end + reachable[idx] = true + degree = @inbounds tree.nodes.degree[idx] + for i in 1:degree + child_idx = @inbounds tree.nodes.children[idx][i] + if child_idx != 0 + mark_reachable!(reachable, tree, child_idx) + end + end +end + # Copy +# Note: break_sharing parameter is ignored since ArrayNode doesn't preserve sharing function copy_node(n::ArrayNode{T,D,S}; break_sharing::Val{BS}=Val(false)) where {T,D,S,BS} + # BS parameter unused - ArrayNode always breaks sharing since each node owns its tree tree = getfield(n, :tree) idx = getfield(n, :idx) + n_capacity = length(tree.nodes) - # Count nodes to determine tree size needed - node_count = count_nodes(n) - # Add some buffer space - tree_size = max(32, node_count * 2) - - # Determine the array type from the existing tree's nodes - # Default to Vector since that's the most common case - # For other array types, we'd need more sophisticated type extraction + # Create new tree with same capacity new_tree = if tree.nodes.degree isa Vector - ArrayTree{T,D}(tree_size; array_type=Vector) + ArrayTree{T,D}(n_capacity; array_type=Vector) else - # For other array types like FixedSizeVector, we just use default - ArrayTree{T,D}(tree_size) + ArrayTree{T,D}(n_capacity) + end + + # Direct array copy - works for both full tree and subtree + new_tree.nodes.degree[:] = tree.nodes.degree + new_tree.nodes.constant[:] = tree.nodes.constant + new_tree.nodes.val[:] = tree.nodes.val + new_tree.nodes.feature[:] = tree.nodes.feature + new_tree.nodes.op[:] = tree.nodes.op + new_tree.nodes.children[:] = tree.nodes.children + + # Set the root to our copied node + new_tree.root_idx = idx + + if idx == tree.root_idx + # Full tree copy - just copy all metadata + new_tree.n_nodes = tree.n_nodes + new_tree.free_count = tree.free_count + new_tree.free_list[:] = tree.free_list + else + # Subtree copy - need to update free list to exclude unreachable nodes + reachable = fill(false, n_capacity) + mark_reachable!(reachable, new_tree, idx) + + # Reset free list with unreachable nodes + new_tree.free_count = 0 + new_tree.n_nodes = 0 + for i in 1:n_capacity + if !reachable[i] + new_tree.free_count += 1 + new_tree.free_list[new_tree.free_count] = UInt16(i) + else + new_tree.n_nodes += 1 + end + end end - new_idx = copy_subtree!(new_tree, tree, idx) - new_tree.root_idx = new_idx - return ArrayNode{T,D,S}(new_tree, new_idx) + return ArrayNode{T,D,S}(new_tree, new_tree.root_idx) end Base.copy(n::ArrayNode) = copy_node(n) +# tree_mapreduce implementation +function tree_mapreduce( + f::F, + op::G, + tree::ArrayNode{T,D,S}, + result_type::Type{RT}=Undefined; + f_on_shared::H=(result, is_shared) -> result, + break_sharing::Val{BS}=Val(false), +) where {F<:Function,G<:Function,H<:Function,T,D,S,RT,BS} + return tree_mapreduce(f, f, op, tree, result_type; f_on_shared, break_sharing) +end + +function tree_mapreduce( + f_leaf::F1, + f_branch::F2, + op::G, + tree::ArrayNode{T,D,S}, + result_type::Type{RT}=Undefined; + f_on_shared::H=(result, is_shared) -> result, + break_sharing::Val{BS}=Val(false), +) where {F1<:Function,F2<:Function,G<:Function,H<:Function,T,D,S,RT,BS} + # ArrayNode doesn't preserve sharing, so we can use simple recursion + if tree.degree == 0 + return f_leaf(tree) + else + # Apply to children + degree = tree.degree + children_results = ntuple(Val(Int(degree))) do i + child = get_children(tree, Val(degree))[i] + tree_mapreduce( + f_leaf, f_branch, op, child, result_type; f_on_shared, break_sharing + ) + end + + # Reduce children results + self_result = f_branch(tree) + return op(self_result, children_results...) + end +end + end # module From 8e179ec6cf3432699b7c7c3e2fe1ea172496798c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 18:01:24 +0100 Subject: [PATCH 15/16] refactor: remove more unneccessary methods --- src/ArrayNode.jl | 51 +++++------------------------------------------- 1 file changed, 5 insertions(+), 46 deletions(-) diff --git a/src/ArrayNode.jl b/src/ArrayNode.jl index 97de16c8..2195a4e1 100644 --- a/src/ArrayNode.jl +++ b/src/ArrayNode.jl @@ -8,16 +8,10 @@ import ..NodeModule: constructorof, with_type_parameters, with_max_degree, - preserve_sharing, - max_degree, default_allocator, get_children, set_children!, unsafe_get_children, - tree_mapreduce, - count_nodes, - set_node!, - any, copy_node export ArrayNode @@ -215,6 +209,8 @@ function ArrayNode{T,D}( tree.nodes.degree[idx] = 0 tree.nodes.constant[idx] = true tree.nodes.val[idx] = val + # Clear children for leaf node + tree.nodes.children[idx] = ntuple(_ -> UInt16(0), Val(D)) return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end @@ -222,6 +218,8 @@ function ArrayNode{T,D}( tree.nodes.degree[idx] = 0 tree.nodes.constant[idx] = false tree.nodes.feature[idx] = feature + # Clear children for leaf node + tree.nodes.children[idx] = ntuple(_ -> UInt16(0), Val(D)) return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end @@ -273,6 +271,7 @@ function ArrayNode{T,D}( tree.nodes.degree[idx] = 0 tree.nodes.constant[idx] = true tree.nodes.val[idx] = zero(T) + tree.nodes.children[idx] = ntuple(_ -> UInt16(0), Val(D)) return ArrayNode{T,D,typeof(tree.nodes)}(tree, idx) end @@ -455,44 +454,4 @@ end Base.copy(n::ArrayNode) = copy_node(n) -# tree_mapreduce implementation -function tree_mapreduce( - f::F, - op::G, - tree::ArrayNode{T,D,S}, - result_type::Type{RT}=Undefined; - f_on_shared::H=(result, is_shared) -> result, - break_sharing::Val{BS}=Val(false), -) where {F<:Function,G<:Function,H<:Function,T,D,S,RT,BS} - return tree_mapreduce(f, f, op, tree, result_type; f_on_shared, break_sharing) -end - -function tree_mapreduce( - f_leaf::F1, - f_branch::F2, - op::G, - tree::ArrayNode{T,D,S}, - result_type::Type{RT}=Undefined; - f_on_shared::H=(result, is_shared) -> result, - break_sharing::Val{BS}=Val(false), -) where {F1<:Function,F2<:Function,G<:Function,H<:Function,T,D,S,RT,BS} - # ArrayNode doesn't preserve sharing, so we can use simple recursion - if tree.degree == 0 - return f_leaf(tree) - else - # Apply to children - degree = tree.degree - children_results = ntuple(Val(Int(degree))) do i - child = get_children(tree, Val(degree))[i] - tree_mapreduce( - f_leaf, f_branch, op, child, result_type; f_on_shared, break_sharing - ) - end - - # Reduce children results - self_result = f_branch(tree) - return op(self_result, children_results...) - end -end - end # module From 76e9a5ba2717184d183c0081f06eeb3ded515c8a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 30 Aug 2025 18:01:38 +0100 Subject: [PATCH 16/16] test: ensure no aliasing --- test/test_array_node.jl | 56 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/test/test_array_node.jl b/test/test_array_node.jl index 18c6eaa1..6a734bcd 100644 --- a/test/test_array_node.jl +++ b/test/test_array_node.jl @@ -179,7 +179,6 @@ end operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[sin, cos]) X = [1.0 2.0; 0.5 1.0] # 2 features, 2 samples - # Test that property access doesn't allocate @check_allocs get_degree(n) = n.degree @check_allocs get_val(n) = n.val @@ -474,3 +473,58 @@ end println("✅ ArrayNode tree_mapreduce operations match Node!") end + +@testitem "ArrayNode copy has no array aliasing" begin + using DynamicExpressions + const ArrayNode = DynamicExpressions.ArrayNode + + # Create a test tree + x1 = ArrayNode{Float64,2,Vector}(; feature=1) + x2 = ArrayNode{Float64,2,Vector}(; feature=2) + tree = ArrayNode{Float64,2,Vector}(; + op=1, + l=ArrayNode{Float64,2,Vector}(; + op=2, l=x1, r=ArrayNode{Float64,2,Vector}(; val=3.5) + ), + r=x2, + ) + + # Test 1: Copy entire tree (root node) + tree_copy = copy(tree) + + # Verify no aliasing - modifying copy shouldn't affect original + tree_copy.val = 999.0 + tree_copy.l.val = 888.0 + + # Check that original is unchanged + @test tree.l.r.val == 3.5 # Original value unchanged + @test tree.l.r.val != 888.0 + + # Verify the backing arrays are different + orig_tree = tree.tree + copy_tree = tree_copy.tree + @test orig_tree !== copy_tree # Different tree objects + @test orig_tree.nodes.val !== copy_tree.nodes.val # Different arrays + @test orig_tree.nodes.degree !== copy_tree.nodes.degree + @test orig_tree.nodes.children !== copy_tree.nodes.children + + # Test 2: Copy subtree (non-root node) + subtree = tree.l + subtree_copy = copy(subtree) + + # Modify the copy + subtree_copy.r.val = 777.0 + + # Original should be unchanged + @test tree.l.r.val == 3.5 + @test subtree.r.val == 3.5 + + # Verify different backing arrays for subtree copy too + subtree_copy_tree = subtree_copy.tree + @test orig_tree !== subtree_copy_tree + @test orig_tree.nodes.val !== subtree_copy_tree.nodes.val + + # Test 3: Verify structure is preserved in copy + @test copy(tree) == tree + @test copy(subtree) == subtree +end