Skip to content

Commit f20cc38

Browse files
authored
Merge pull request #33 from SymbolicML/preserve-container
Preserve container types
2 parents a6fe9b1 + 924a831 commit f20cc38

File tree

6 files changed

+128
-98
lines changed

6 files changed

+128
-98
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ julia = "1.6"
3030
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3131
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3232
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
33+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3334
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3435

3536
[targets]
36-
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff"]
37+
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays"]

src/EvaluateEquation.jl

Lines changed: 42 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,21 @@ module EvaluateEquationModule
33
import LoopVectorization: @turbo, indices
44
import ..EquationModule: Node, string_tree
55
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
6-
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array
6+
import ..UtilsModule: @return_on_false, @maybe_turbo, is_bad_array, fill_similar
77
import ..EquationUtilsModule: is_constant
88

9-
macro return_on_check(val, T, n)
10-
# This will generate the following code:
11-
# if !isfinite(val)
12-
# return (Array{T, 1}(undef, n), false)
13-
# end
14-
9+
macro return_on_check(val, X)
1510
:(
1611
if !isfinite($(esc(val)))
17-
return (Array{$(esc(T)),1}(undef, $(esc(n))), false)
12+
return (similar($(esc(X)), axes($(esc(X)), 2)), false)
1813
end
1914
)
2015
end
2116

22-
macro return_on_nonfinite_array(array, T, n)
17+
macro return_on_nonfinite_array(array)
2318
:(
2419
if is_bad_array($(esc(array)))
25-
return (Array{$(esc(T)),1}(undef, $(esc(n))), false)
20+
return ($(esc(array)), false)
2621
end
2722
)
2823
end
@@ -64,15 +59,14 @@ which speed up evaluation significantly.
6459
function eval_tree_array(
6560
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false
6661
)::Tuple{AbstractVector{T},Bool} where {T<:Number}
67-
n = size(cX, 2)
6862
if turbo
6963
@assert T in (Float32, Float64)
7064
end
7165
result, finished = _eval_tree_array(
7266
tree, cX, operators, (turbo ? Val(true) : Val(false))
7367
)
7468
@return_on_false finished result
75-
@return_on_nonfinite_array result T n
69+
@return_on_nonfinite_array result
7670
return result, finished
7771
end
7872
function eval_tree_array(
@@ -81,23 +75,22 @@ function eval_tree_array(
8175
T = promote_type(T1, T2)
8276
@warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)."
8377
tree = convert(Node{T}, tree)
84-
cX = convert(AbstractMatrix{T}, cX)
78+
cX = T.(cX)
8579
return eval_tree_array(tree, cX, operators; turbo=turbo)
8680
end
8781

8882
function _eval_tree_array(
8983
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}
9084
)::Tuple{AbstractVector{T},Bool} where {T<:Number,turbo}
91-
n = size(cX, 2)
9285
# First, we see if there are only constants in the tree - meaning
9386
# we can just return the constant result.
9487
if tree.degree == 0
9588
return deg0_eval(tree, cX)
9689
elseif is_constant(tree)
9790
# Speed hack for constant trees.
9891
result, flag = _eval_constant_tree(tree, operators)
99-
!flag && return Array{T,1}(undef, size(cX, 2)), false
100-
return fill(result, size(cX, 2)), true
92+
!flag && return similar(cX, axes(cX, 2)), false
93+
return fill_similar(result, cX, axes(cX, 2)), true
10194
elseif tree.degree == 1
10295
op = operators.unaops[tree.op]
10396
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
@@ -113,7 +106,7 @@ function _eval_tree_array(
113106
# op(x), for any x.
114107
(cumulator, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
115108
@return_on_false complete cumulator
116-
@return_on_nonfinite_array cumulator T n
109+
@return_on_nonfinite_array cumulator
117110
return deg1_eval(cumulator, op, Val(turbo))
118111

119112
elseif tree.degree == 2
@@ -125,22 +118,22 @@ function _eval_tree_array(
125118
elseif tree.r.degree == 0
126119
(cumulator_l, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
127120
@return_on_false complete cumulator_l
128-
@return_on_nonfinite_array cumulator_l T n
121+
@return_on_nonfinite_array cumulator_l
129122
# op(x, y), where y is a constant or variable but x is not.
130123
return deg2_r0_eval(tree, cumulator_l, cX, op, Val(turbo))
131124
elseif tree.l.degree == 0
132125
(cumulator_r, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
133126
@return_on_false complete cumulator_r
134-
@return_on_nonfinite_array cumulator_r T n
127+
@return_on_nonfinite_array cumulator_r
135128
# op(x, y), where x is a constant or variable but y is not.
136129
return deg2_l0_eval(tree, cumulator_r, cX, op, Val(turbo))
137130
end
138131
(cumulator_l, complete) = _eval_tree_array(tree.l, cX, operators, Val(turbo))
139132
@return_on_false complete cumulator_l
140-
@return_on_nonfinite_array cumulator_l T n
133+
@return_on_nonfinite_array cumulator_l
141134
(cumulator_r, complete) = _eval_tree_array(tree.r, cX, operators, Val(turbo))
142135
@return_on_false complete cumulator_r
143-
@return_on_nonfinite_array cumulator_r T n
136+
@return_on_nonfinite_array cumulator_r
144137
# op(x, y), for any x or y
145138
return deg2_eval(cumulator_l, cumulator_r, op, Val(turbo))
146139
end
@@ -170,8 +163,7 @@ function deg0_eval(
170163
tree::Node{T}, cX::AbstractMatrix{T}
171164
)::Tuple{AbstractVector{T},Bool} where {T<:Number}
172165
if tree.constant
173-
n = size(cX, 2)
174-
return (fill(tree.val::T, n), true)
166+
return (fill_similar(tree.val::T, cX, axes(cX, 2)), true)
175167
else
176168
return (cX[tree.feature, :], true)
177169
end
@@ -180,22 +172,21 @@ end
180172
function deg1_l2_ll0_lr0_eval(
181173
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
182174
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo}
183-
n = size(cX, 2)
184175
if tree.l.l.constant && tree.l.r.constant
185176
val_ll = tree.l.l.val::T
186177
val_lr = tree.l.r.val::T
187-
@return_on_check val_ll T n
188-
@return_on_check val_lr T n
178+
@return_on_check val_ll cX
179+
@return_on_check val_lr cX
189180
x_l = op_l(val_ll, val_lr)::T
190-
@return_on_check x_l T n
181+
@return_on_check x_l cX
191182
x = op(x_l)::T
192-
@return_on_check x T n
193-
return (fill(x, n), true)
183+
@return_on_check x cX
184+
return (fill_similar(x, cX, axes(cX, 2)), true)
194185
elseif tree.l.l.constant
195186
val_ll = tree.l.l.val::T
196-
@return_on_check val_ll T n
187+
@return_on_check val_ll cX
197188
feature_lr = tree.l.r.feature
198-
cumulator = Array{T,1}(undef, n)
189+
cumulator = similar(cX, axes(cX, 2))
199190
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
200191
x_l = op_l(val_ll, cX[feature_lr, j])::T
201192
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
@@ -205,8 +196,8 @@ function deg1_l2_ll0_lr0_eval(
205196
elseif tree.l.r.constant
206197
feature_ll = tree.l.l.feature
207198
val_lr = tree.l.r.val::T
208-
@return_on_check val_lr T n
209-
cumulator = Array{T,1}(undef, n)
199+
@return_on_check val_lr cX
200+
cumulator = similar(cX, axes(cX, 2))
210201
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
211202
x_l = op_l(cX[feature_ll, j], val_lr)::T
212203
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
@@ -216,7 +207,7 @@ function deg1_l2_ll0_lr0_eval(
216207
else
217208
feature_ll = tree.l.l.feature
218209
feature_lr = tree.l.r.feature
219-
cumulator = Array{T,1}(undef, n)
210+
cumulator = similar(cX, axes(cX, 2))
220211
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
221212
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
222213
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
@@ -230,18 +221,17 @@ end
230221
function deg1_l1_ll0_eval(
231222
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
232223
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo}
233-
n = size(cX, 2)
234224
if tree.l.l.constant
235225
val_ll = tree.l.l.val::T
236-
@return_on_check val_ll T n
226+
@return_on_check val_ll cX
237227
x_l = op_l(val_ll)::T
238-
@return_on_check x_l T n
228+
@return_on_check x_l cX
239229
x = op(x_l)::T
240-
@return_on_check x T n
241-
return (fill(x, n), true)
230+
@return_on_check x cX
231+
return (fill_similar(x, cX, axes(cX, 2)), true)
242232
else
243233
feature_ll = tree.l.l.feature
244-
cumulator = Array{T,1}(undef, n)
234+
cumulator = similar(cX, axes(cX, 2))
245235
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
246236
x_l = op_l(cX[feature_ll, j])::T
247237
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
@@ -255,35 +245,34 @@ end
255245
function deg2_l0_r0_eval(
256246
tree::Node{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo}
257247
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
258-
n = size(cX, 2)
259248
if tree.l.constant && tree.r.constant
260249
val_l = tree.l.val::T
261-
@return_on_check val_l T n
250+
@return_on_check val_l cX
262251
val_r = tree.r.val::T
263-
@return_on_check val_r T n
252+
@return_on_check val_r cX
264253
x = op(val_l, val_r)::T
265-
@return_on_check x T n
266-
return (fill(x, n), true)
254+
@return_on_check x cX
255+
return (fill_similar(x, cX, axes(cX, 2)), true)
267256
elseif tree.l.constant
268-
cumulator = Array{T,1}(undef, n)
257+
cumulator = similar(cX, axes(cX, 2))
269258
val_l = tree.l.val::T
270-
@return_on_check val_l T n
259+
@return_on_check val_l cX
271260
feature_r = tree.r.feature
272261
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
273262
x = op(val_l, cX[feature_r, j])::T
274263
cumulator[j] = x
275264
end
276265
elseif tree.r.constant
277-
cumulator = Array{T,1}(undef, n)
266+
cumulator = similar(cX, axes(cX, 2))
278267
feature_l = tree.l.feature
279268
val_r = tree.r.val::T
280-
@return_on_check val_r T n
269+
@return_on_check val_r cX
281270
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
282271
x = op(cX[feature_l, j], val_r)::T
283272
cumulator[j] = x
284273
end
285274
else
286-
cumulator = Array{T,1}(undef, n)
275+
cumulator = similar(cX, axes(cX, 2))
287276
feature_l = tree.l.feature
288277
feature_r = tree.r.feature
289278
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
@@ -298,10 +287,9 @@ end
298287
function deg2_l0_eval(
299288
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
300289
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
301-
n = size(cX, 2)
302290
if tree.l.constant
303291
val = tree.l.val::T
304-
@return_on_check val T n
292+
@return_on_check val cX
305293
@maybe_turbo turbo for j in indices(cumulator)
306294
x = op(val, cumulator[j])::T
307295
cumulator[j] = x
@@ -320,10 +308,9 @@ end
320308
function deg2_r0_eval(
321309
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
322310
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
323-
n = size(cX, 2)
324311
if tree.r.constant
325312
val = tree.r.val::T
326-
@return_on_check val T n
313+
@return_on_check val cX
327314
@maybe_turbo turbo for j in indices(cumulator)
328315
x = op(cumulator[j], val)::T
329316
cumulator[j] = x
@@ -389,10 +376,9 @@ Evaluate an expression tree in a way that can be auto-differentiated.
389376
function differentiable_eval_tree_array(
390377
tree::Node{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum
391378
)::Tuple{AbstractVector{T},Bool} where {T<:Number,T1}
392-
n = size(cX, 2)
393379
if tree.degree == 0
394380
if tree.constant
395-
return (ones(T, n) .* convert(T, tree.val), true)
381+
return (fill_similar(one(T), cX, axes(cX, 2)) .* tree.val, true)
396382
else
397383
return (cX[tree.feature, :], true)
398384
end

0 commit comments

Comments
 (0)