Skip to content

Commit ebe61f9

Browse files
committed
Fix helper function generators
1 parent 061c198 commit ebe61f9

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ module OperatorEnumConstructionModule
22

33
import Zygote: gradient
44
import ..UtilsModule: max_ops
5-
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
5+
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
66
import ..EquationModule: string_tree, Node
77
import ..EvaluateEquationModule: eval_tree_array
88
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
99

10-
function create_evaluation_helper_functions(operators::OperatorEnum)
10+
function create_evaluation_helpers!(operators::OperatorEnum)
1111
@eval begin
1212
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
1313
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
@@ -37,7 +37,7 @@ function create_evaluation_helper_functions(operators::OperatorEnum)
3737
end
3838
end
3939

40-
function create_evaluation_helper_functions(operators::GenericOperatorEnum)
40+
function create_evaluation_helpers!(operators::GenericOperatorEnum)
4141
@eval begin
4242
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
4343
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
@@ -54,11 +54,14 @@ function create_evaluation_helper_functions(operators::GenericOperatorEnum)
5454
end
5555
end
5656

57-
function create_node_helper_functions(
57+
function create_construction_helpers!(
5858
operators::AbstractOperatorEnum; extend_user_operators::Bool=false
5959
)
60-
for (op, f) in enumerate(map(Symbol, binary_operators))
61-
if typeof(operators) <: OperatorEnum
60+
is_scalar_operator_enum = typeof(operators) <: OperatorEnum
61+
type_requirements = is_scalar_operator_enum ? Real : Any
62+
63+
for (op, f) in enumerate(map(Symbol, operators.binops))
64+
if is_scalar_operator_enum
6265
f = if f in [:pow, :safe_pow]
6366
Symbol(^)
6467
else
@@ -74,7 +77,9 @@ function create_node_helper_functions(
7477
Base.MainInclude.eval(
7578
quote
7679
import DynamicExpressions: Node
77-
function $f(l::Node{T1}, r::Node{T2}) where {T1<:Real,T2<:Real}
80+
function $f(
81+
l::Node{T1}, r::Node{T2}
82+
) where {T1<:$type_requirements,T2<:$type_requirements}
7883
T = promote_type(T1, T2)
7984
l = convert(Node{T}, l)
8085
r = convert(Node{T}, r)
@@ -84,7 +89,9 @@ function create_node_helper_functions(
8489
return Node($op, l, r)
8590
end
8691
end
87-
function $f(l::Node{T1}, r::T2) where {T1<:Real,T2<:Real}
92+
function $f(
93+
l::Node{T1}, r::T2
94+
) where {T1<:$type_requirements,T2<:$type_requirements}
8895
T = promote_type(T1, T2)
8996
l = convert(Node{T}, l)
9097
r = convert(T, r)
@@ -94,7 +101,9 @@ function create_node_helper_functions(
94101
Node($op, l, Node(; val=r))
95102
end
96103
end
97-
function $f(l::T1, r::Node{T2}) where {T1<:Real,T2<:Real}
104+
function $f(
105+
l::T1, r::Node{T2}
106+
) where {T1<:$type_requirements,T2<:$type_requirements}
98107
T = promote_type(T1, T2)
99108
l = convert(T, l)
100109
r = convert(Node{T}, r)
@@ -108,7 +117,7 @@ function create_node_helper_functions(
108117
)
109118
end
110119
# Redefine Base operations:
111-
for (op, f) in enumerate(map(Symbol, unary_operators))
120+
for (op, f) in enumerate(map(Symbol, operators.unaops))
112121
if isdefined(Base, f)
113122
f = :(Base.$(f))
114123
elseif !extend_user_operators
@@ -118,7 +127,7 @@ function create_node_helper_functions(
118127
Base.MainInclude.eval(
119128
quote
120129
import DynamicExpressions: Node
121-
function $f(l::Node{T})::Node{T} where {T<:Real}
130+
function $f(l::Node{T})::Node{T} where {T<:$type_requirements}
122131
return l.constant ? Node(; val=$f(l.val)) : Node($op, l)
123132
end
124133
end,
@@ -209,8 +218,8 @@ function OperatorEnum(;
209218
)
210219

211220
if define_helper_functions
212-
create_node_helper_functions(operators; extend_user_operators=extend_user_operators)
213-
create_evaluation_helper_functions(operators)
221+
create_construction_helpers!(operators; extend_user_operators=extend_user_operators)
222+
create_evaluation_helpers!(operators)
214223
end
215224

216225
return operators
@@ -249,8 +258,8 @@ function GenericOperatorEnum(;
249258
operators = GenericOperatorEnum(binary_operators, unary_operators)
250259

251260
if define_helper_functions
252-
create_node_helper_functions(operators; extend_user_operators=extend_user_operators)
253-
create_evaluation_helper_functions(operators)
261+
create_construction_helpers!(operators; extend_user_operators=extend_user_operators)
262+
create_evaluation_helpers!(operators)
254263
end
255264

256265
return operators

0 commit comments

Comments
 (0)