Skip to content

Commit 03ee2c5

Browse files
authored
Merge pull request #16 from SymbolicML/complex-numbers
Open OperatorEnum to complex numbers
2 parents 34271cb + e53f306 commit 03ee2c5

File tree

9 files changed

+52
-50
lines changed

9 files changed

+52
-50
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

docs/src/eval.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Given an expression tree specified with a `Node` type, you may evaluate the expr
44
over an array of data with the following command:
55

66
```@docs
7-
eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Real}
7+
eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
88
```
99

1010
Assuming you are only using a single `OperatorEnum`, you can also use
@@ -45,15 +45,15 @@ all variables (or, all constants). Both use forward-mode automatic, but use
4545
`Zygote.jl` to compute derivatives of each operator, so this is very efficient.
4646

4747
```@docs
48-
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int) where {T<:Real}
49-
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false) where {T<:Real}
48+
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int) where {T<:Number}
49+
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; variable::Bool=false) where {T<:Number}
5050
```
5151

5252
Alternatively, you can compute higher-order derivatives by using `ForwardDiff` on
5353
the function `differentiable_eval_tree_array`, although this will be slower.
5454

5555
```@docs
56-
differentiable_eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Real}
56+
differentiable_eval_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum) where {T<:Number}
5757
```
5858

5959
## Printing

docs/src/types.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Construct this operator specification as follows:
1616
OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, define_helper_functions::Bool=true)
1717
```
1818

19-
This is just for scalar real operators. However, you can use
19+
This is just for scalar operators. However, you can use
2020
the following for more general operators:
2121

2222
```@docs

src/EvaluateEquation.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ which speed up evaluation significantly.
6363
"""
6464
function eval_tree_array(
6565
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false
66-
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
66+
)::Tuple{AbstractVector{T},Bool} where {T<:Number}
6767
n = size(cX, 2)
6868
if turbo
6969
@assert T in (Float32, Float64)
@@ -77,7 +77,7 @@ function eval_tree_array(
7777
end
7878
function eval_tree_array(
7979
tree::Node{T1}, cX::AbstractMatrix{T2}, operators::OperatorEnum; turbo::Bool=false
80-
) where {T1<:Real,T2<:Real}
80+
) where {T1<:Number,T2<:Number}
8181
T = promote_type(T1, T2)
8282
@warn "Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2)."
8383
tree = convert(Node{T}, tree)
@@ -87,7 +87,7 @@ end
8787

8888
function _eval_tree_array(
8989
tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, ::Val{turbo}
90-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,turbo}
90+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,turbo}
9191
n = size(cX, 2)
9292
# First, we see if there are only constants in the tree - meaning
9393
# we can just return the constant result.
@@ -148,7 +148,7 @@ end
148148

149149
function deg2_eval(
150150
cumulator_l::AbstractVector{T}, cumulator_r::AbstractVector{T}, op::F, ::Val{turbo}
151-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
151+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
152152
@maybe_turbo turbo for j in indices(cumulator_l)
153153
x = op(cumulator_l[j], cumulator_r[j])::T
154154
cumulator_l[j] = x
@@ -158,7 +158,7 @@ end
158158

159159
function deg1_eval(
160160
cumulator::AbstractVector{T}, op::F, ::Val{turbo}
161-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
161+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
162162
@maybe_turbo turbo for j in indices(cumulator)
163163
x = op(cumulator[j])::T
164164
cumulator[j] = x
@@ -168,7 +168,7 @@ end
168168

169169
function deg0_eval(
170170
tree::Node{T}, cX::AbstractMatrix{T}
171-
)::Tuple{AbstractVector{T},Bool} where {T<:Real}
171+
)::Tuple{AbstractVector{T},Bool} where {T<:Number}
172172
if tree.constant
173173
n = size(cX, 2)
174174
return (fill(tree.val::T, n), true)
@@ -179,7 +179,7 @@ end
179179

180180
function deg1_l2_ll0_lr0_eval(
181181
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
182-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,F2,turbo}
182+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo}
183183
n = size(cX, 2)
184184
if tree.l.l.constant && tree.l.r.constant
185185
val_ll = tree.l.l.val::T
@@ -229,7 +229,7 @@ end
229229
# op(op2(x)) for x variable or constant
230230
function deg1_l1_ll0_eval(
231231
tree::Node{T}, cX::AbstractMatrix{T}, op::F, op_l::F2, ::Val{turbo}
232-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,F2,turbo}
232+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,F2,turbo}
233233
n = size(cX, 2)
234234
if tree.l.l.constant
235235
val_ll = tree.l.l.val::T
@@ -254,7 +254,7 @@ end
254254
# op(x, y) for x and y variable/constant
255255
function deg2_l0_r0_eval(
256256
tree::Node{T}, cX::AbstractMatrix{T}, op::F, ::Val{turbo}
257-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
257+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
258258
n = size(cX, 2)
259259
if tree.l.constant && tree.r.constant
260260
val_l = tree.l.val::T
@@ -297,7 +297,7 @@ end
297297
# op(x, y) for x variable/constant, y arbitrary
298298
function deg2_l0_eval(
299299
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
300-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
300+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
301301
n = size(cX, 2)
302302
if tree.l.constant
303303
val = tree.l.val::T
@@ -319,7 +319,7 @@ end
319319
# op(x, y) for x arbitrary, y variable/constant
320320
function deg2_r0_eval(
321321
tree::Node{T}, cumulator::AbstractVector{T}, cX::AbstractArray{T}, op::F, ::Val{turbo}
322-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,turbo}
322+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,turbo}
323323
n = size(cX, 2)
324324
if tree.r.constant
325325
val = tree.r.val::T
@@ -339,15 +339,15 @@ function deg2_r0_eval(
339339
end
340340

341341
"""
342-
_eval_constant_tree(tree::Node{T}, operators::OperatorEnum)::Tuple{T,Bool} where {T<:Real}
342+
_eval_constant_tree(tree::Node{T}, operators::OperatorEnum)::Tuple{T,Bool} where {T<:Number}
343343
344344
Evaluate a tree which is assumed to not contain any variable nodes. This
345345
gives better performance, as we do not need to perform computation
346346
over an entire array when the values are all the same.
347347
"""
348348
function _eval_constant_tree(
349349
tree::Node{T}, operators::OperatorEnum
350-
)::Tuple{T,Bool} where {T<:Real}
350+
)::Tuple{T,Bool} where {T<:Number}
351351
if tree.degree == 0
352352
return deg0_eval_constant(tree)
353353
elseif tree.degree == 1
@@ -357,13 +357,13 @@ function _eval_constant_tree(
357357
end
358358
end
359359

360-
@inline function deg0_eval_constant(tree::Node{T})::Tuple{T,Bool} where {T<:Real}
360+
@inline function deg0_eval_constant(tree::Node{T})::Tuple{T,Bool} where {T<:Number}
361361
return tree.val::T, true
362362
end
363363

364364
function deg1_eval_constant(
365365
tree::Node{T}, op::F, operators::OperatorEnum
366-
)::Tuple{T,Bool} where {T<:Real,F}
366+
)::Tuple{T,Bool} where {T<:Number,F}
367367
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
368368
!complete && return zero(T), false
369369
output = op(cumulator)::T
@@ -372,7 +372,7 @@ end
372372

373373
function deg2_eval_constant(
374374
tree::Node{T}, op::F, operators::OperatorEnum
375-
)::Tuple{T,Bool} where {T<:Real,F}
375+
)::Tuple{T,Bool} where {T<:Number,F}
376376
(cumulator, complete) = _eval_constant_tree(tree.l, operators)
377377
!complete && return zero(T), false
378378
(cumulator2, complete2) = _eval_constant_tree(tree.r, operators)
@@ -388,7 +388,7 @@ Evaluate an expression tree in a way that can be auto-differentiated.
388388
"""
389389
function differentiable_eval_tree_array(
390390
tree::Node{T1}, cX::AbstractMatrix{T}, operators::OperatorEnum
391-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,T1}
391+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,T1}
392392
n = size(cX, 2)
393393
if tree.degree == 0
394394
if tree.constant
@@ -405,7 +405,7 @@ end
405405

406406
function deg1_diff_eval(
407407
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
408-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,T1}
408+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,T1}
409409
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
410410
@return_on_false complete left
411411
out = op.(left)
@@ -415,7 +415,7 @@ end
415415

416416
function deg2_diff_eval(
417417
tree::Node{T1}, cX::AbstractMatrix{T}, op::F, operators::OperatorEnum
418-
)::Tuple{AbstractVector{T},Bool} where {T<:Real,F,T1}
418+
)::Tuple{AbstractVector{T},Bool} where {T<:Number,F,T1}
419419
(left, complete) = differentiable_eval_tree_array(tree.l, cX, operators)
420420
@return_on_false complete left
421421
(right, complete2) = differentiable_eval_tree_array(tree.r, cX, operators)

src/EvaluateEquationDerivative.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function eval_diff_tree_array(
4343
operators::OperatorEnum,
4444
direction::Int;
4545
turbo::Bool=false,
46-
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
46+
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
4747
assert_autodiff_enabled(operators)
4848
# TODO: Implement quick check for whether the variable is actually used
4949
# in this tree. Otherwise, return zero.
@@ -57,7 +57,7 @@ function eval_diff_tree_array(
5757
operators::OperatorEnum,
5858
direction::Int;
5959
turbo::Bool=false,
60-
) where {T1<:Real,T2<:Real}
60+
) where {T1<:Number,T2<:Number}
6161
T = promote_type(T1, T2)
6262
@warn "Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2)."
6363
tree = convert(Node{T}, tree)
@@ -71,7 +71,7 @@ function _eval_diff_tree_array(
7171
operators::OperatorEnum,
7272
direction::Int,
7373
::Val{turbo},
74-
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,turbo}
74+
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,turbo}
7575
evaluation, derivative, complete = if tree.degree == 0
7676
diff_deg0_eval(tree, cX, direction)
7777
elseif tree.degree == 1
@@ -101,7 +101,7 @@ end
101101

102102
function diff_deg0_eval(
103103
tree::Node{T}, cX::AbstractMatrix{T}, direction::Int
104-
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real}
104+
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
105105
n = size(cX, 2)
106106
const_part = deg0_eval(tree, cX)[1]
107107
derivative_part =
@@ -117,7 +117,7 @@ function diff_deg1_eval(
117117
operators::OperatorEnum,
118118
direction::Int,
119119
::Val{turbo},
120-
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF,turbo}
120+
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
121121
n = size(cX, 2)
122122
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
123123
tree.l, cX, operators, direction, Val(turbo)
@@ -143,7 +143,7 @@ function diff_deg2_eval(
143143
operators::OperatorEnum,
144144
direction::Int,
145145
::Val{turbo},
146-
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Real,F,dF,turbo}
146+
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
147147
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
148148
tree.l, cX, operators, direction, Val(turbo)
149149
)
@@ -194,7 +194,7 @@ function eval_grad_tree_array(
194194
operators::OperatorEnum;
195195
variable::Bool=false,
196196
turbo::Bool=false,
197-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real}
197+
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number}
198198
assert_autodiff_enabled(operators)
199199
n = size(cX, 2)
200200
if variable
@@ -224,7 +224,7 @@ function eval_grad_tree_array(
224224
operators::OperatorEnum,
225225
::Val{variable},
226226
::Val{turbo},
227-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable,turbo}
227+
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,turbo}
228228
evaluation, gradient, complete = _eval_grad_tree_array(
229229
tree, n, n_gradients, index_tree, cX, operators, Val(variable), Val(turbo)
230230
)
@@ -238,7 +238,7 @@ function eval_grad_tree_array(
238238
operators::OperatorEnum;
239239
variable::Bool=false,
240240
turbo::Bool=false,
241-
) where {T1<:Real,T2<:Real}
241+
) where {T1<:Number,T2<:Number}
242242
T = promote_type(T1, T2)
243243
return eval_grad_tree_array(
244244
convert(Node{T}, tree),
@@ -258,7 +258,7 @@ function _eval_grad_tree_array(
258258
operators::OperatorEnum,
259259
::Val{variable},
260260
::Val{turbo},
261-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable,turbo}
261+
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,turbo}
262262
if tree.degree == 0
263263
grad_deg0_eval(tree, n, n_gradients, index_tree, cX, Val(variable))
264264
elseif tree.degree == 1
@@ -297,7 +297,7 @@ function grad_deg0_eval(
297297
index_tree::NodeIndex,
298298
cX::AbstractMatrix{T},
299299
::Val{variable},
300-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,variable}
300+
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable}
301301
const_part = deg0_eval(tree, cX)[1]
302302

303303
if variable == tree.constant
@@ -321,7 +321,7 @@ function grad_deg1_eval(
321321
operators::OperatorEnum,
322322
::Val{variable},
323323
::Val{turbo},
324-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable,turbo}
324+
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,F,dF,variable,turbo}
325325
(cumulator, dcumulator, complete) = eval_grad_tree_array(
326326
tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
327327
)
@@ -350,7 +350,7 @@ function grad_deg2_eval(
350350
operators::OperatorEnum,
351351
::Val{variable},
352352
::Val{turbo},
353-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Real,F,dF,variable,turbo}
353+
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,F,dF,variable,turbo}
354354
(cumulator1, dcumulator1, complete) = eval_grad_tree_array(
355355
tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
356356
)

src/OperatorEnum.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ abstract type AbstractOperatorEnum end
77
88
Defines an enum over operators, along with their derivatives.
99
# Fields
10-
- `binops`: A tuple of binary operators. Real scalar input type.
11-
- `unaops`: A tuple of unary operators. Real scalar input type.
10+
- `binops`: A tuple of binary operators. Scalar input type.
11+
- `unaops`: A tuple of unary operators. Scalar input type.
1212
- `diff_binops`: A tuple of Zygote-computed derivatives of the binary operators.
1313
- `diff_unaops`: A tuple of Zygote-computed derivatives of the unary operators.
1414
"""

src/OperatorEnumConstruction.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
126126
local type_requirements
127127
local build_converters
128128
if isa($operators, OperatorEnum)
129-
type_requirements = Real
129+
type_requirements = Number
130130
build_converters = true
131131
else
132132
type_requirements = Any
@@ -257,9 +257,9 @@ and `(::Node)(X)`.
257257
258258
# Arguments
259259
- `binary_operators::Vector{Function}`: A vector of functions, each of which is a binary
260-
operator on real scalars.
260+
operator.
261261
- `unary_operators::Vector{Function}`: A vector of functions, each of which is a unary
262-
operator on real scalars.
262+
operator.
263263
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
264264
and evaluating node types. Turn this off when doing precompilation. Note that these
265265
are *not* needed for the package to work; they are purely for convenience.

src/SimplifyEquation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function combine_operators(
99
tree::Node{T},
1010
operators::AbstractOperatorEnum,
1111
id_map::IdDict{Node{T},Node{T}}=IdDict{Node{T},Node{T}}(),
12-
)::Node{T} where {T<:Real}
12+
)::Node{T} where {T}
1313
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree before.
1414
# ((const + var) + const) => (const + var)
1515
# ((const * var) * const) => (const * var)

0 commit comments

Comments
 (0)