@@ -2,12 +2,12 @@ module OperatorEnumConstructionModule
2
2
3
3
import Zygote: gradient
4
4
import .. UtilsModule: max_ops
5
- import .. OperatorEnumModule: OperatorEnum, GenericOperatorEnum
5
+ import .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
6
6
import .. EquationModule: string_tree, Node
7
7
import .. EvaluateEquationModule: eval_tree_array
8
8
import .. EvaluateEquationDerivativeModule: eval_grad_tree_array
9
9
10
- function create_evaluation_helper_functions (operators:: OperatorEnum )
10
+ function create_evaluation_helpers! (operators:: OperatorEnum )
11
11
@eval begin
12
12
Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
13
13
Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
@@ -37,7 +37,7 @@ function create_evaluation_helper_functions(operators::OperatorEnum)
37
37
end
38
38
end
39
39
40
- function create_evaluation_helper_functions (operators:: GenericOperatorEnum )
40
+ function create_evaluation_helpers! (operators:: GenericOperatorEnum )
41
41
@eval begin
42
42
Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
43
43
Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
@@ -54,11 +54,14 @@ function create_evaluation_helper_functions(operators::GenericOperatorEnum)
54
54
end
55
55
end
56
56
57
- function create_node_helper_functions (
57
+ function create_construction_helpers! (
58
58
operators:: AbstractOperatorEnum ; extend_user_operators:: Bool = false
59
59
)
60
- for (op, f) in enumerate (map (Symbol, binary_operators))
61
- if typeof (operators) <: OperatorEnum
60
+ is_scalar_operator_enum = typeof (operators) <: OperatorEnum
61
+ type_requirements = is_scalar_operator_enum ? Real : Any
62
+
63
+ for (op, f) in enumerate (map (Symbol, operators. binops))
64
+ if is_scalar_operator_enum
62
65
f = if f in [:pow , :safe_pow ]
63
66
Symbol (^ )
64
67
else
@@ -74,7 +77,9 @@ function create_node_helper_functions(
74
77
Base. MainInclude. eval (
75
78
quote
76
79
import DynamicExpressions: Node
77
- function $f (l:: Node{T1} , r:: Node{T2} ) where {T1<: Real ,T2<: Real }
80
+ function $f (
81
+ l:: Node{T1} , r:: Node{T2}
82
+ ) where {T1<: $type_requirements ,T2<: $type_requirements }
78
83
T = promote_type (T1, T2)
79
84
l = convert (Node{T}, l)
80
85
r = convert (Node{T}, r)
@@ -84,7 +89,9 @@ function create_node_helper_functions(
84
89
return Node ($ op, l, r)
85
90
end
86
91
end
87
- function $f (l:: Node{T1} , r:: T2 ) where {T1<: Real ,T2<: Real }
92
+ function $f (
93
+ l:: Node{T1} , r:: T2
94
+ ) where {T1<: $type_requirements ,T2<: $type_requirements }
88
95
T = promote_type (T1, T2)
89
96
l = convert (Node{T}, l)
90
97
r = convert (T, r)
@@ -94,7 +101,9 @@ function create_node_helper_functions(
94
101
Node ($ op, l, Node (; val= r))
95
102
end
96
103
end
97
- function $f (l:: T1 , r:: Node{T2} ) where {T1<: Real ,T2<: Real }
104
+ function $f (
105
+ l:: T1 , r:: Node{T2}
106
+ ) where {T1<: $type_requirements ,T2<: $type_requirements }
98
107
T = promote_type (T1, T2)
99
108
l = convert (T, l)
100
109
r = convert (Node{T}, r)
@@ -108,7 +117,7 @@ function create_node_helper_functions(
108
117
)
109
118
end
110
119
# Redefine Base operations:
111
- for (op, f) in enumerate (map (Symbol, unary_operators ))
120
+ for (op, f) in enumerate (map (Symbol, operators . unaops ))
112
121
if isdefined (Base, f)
113
122
f = :(Base.$ (f))
114
123
elseif ! extend_user_operators
@@ -118,7 +127,7 @@ function create_node_helper_functions(
118
127
Base. MainInclude. eval (
119
128
quote
120
129
import DynamicExpressions: Node
121
- function $f (l:: Node{T} ):: Node{T} where {T<: Real }
130
+ function $f (l:: Node{T} ):: Node{T} where {T<: $type_requirements }
122
131
return l. constant ? Node (; val= $ f (l. val)) : Node ($ op, l)
123
132
end
124
133
end ,
@@ -209,8 +218,8 @@ function OperatorEnum(;
209
218
)
210
219
211
220
if define_helper_functions
212
- create_node_helper_functions (operators; extend_user_operators= extend_user_operators)
213
- create_evaluation_helper_functions (operators)
221
+ create_construction_helpers! (operators; extend_user_operators= extend_user_operators)
222
+ create_evaluation_helpers! (operators)
214
223
end
215
224
216
225
return operators
@@ -249,8 +258,8 @@ function GenericOperatorEnum(;
249
258
operators = GenericOperatorEnum (binary_operators, unary_operators)
250
259
251
260
if define_helper_functions
252
- create_node_helper_functions (operators; extend_user_operators= extend_user_operators)
253
- create_evaluation_helper_functions (operators)
261
+ create_construction_helpers! (operators; extend_user_operators= extend_user_operators)
262
+ create_evaluation_helpers! (operators)
254
263
end
255
264
256
265
return operators
0 commit comments