Skip to content

Commit 292deb9

Browse files
authored
Merge pull request #5 from SymbolicML/explicit-errors
By default, throw errors from `MethodError`
2 parents 70bd87e + 8702a47 commit 292deb9

File tree

4 files changed

+107
-17
lines changed

4 files changed

+107
-17
lines changed

src/EvaluateEquation.jl

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module EvaluateEquationModule
22

3-
import ..EquationModule: Node
3+
import ..EquationModule: Node, string_tree
44
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
55
import ..UtilsModule: @return_on_false, is_bad_array, vals
66
import ..EquationUtilsModule: is_constant
@@ -47,8 +47,12 @@ function eval(current_node)
4747
The bulk of the code is for optimizations and pre-emptive NaN/Inf checks,
4848
which speed up evaluation significantly.
4949
50-
# Returns
50+
# Arguments
51+
- `tree::Node`: The root node of the tree to evaluate.
52+
- `cX::AbstractMatrix{T}`: The input data to evaluate the tree on.
53+
- `operators::OperatorEnum`: The operators used in the tree.
5154
55+
# Returns
5256
- `(output, complete)::Tuple{AbstractVector{T}, Bool}`: the result,
5357
which is a 1D array, as well as if the evaluation completed
5458
successfully (true/false). A `false` complete means an infinity
@@ -461,19 +465,51 @@ function eval(current_node)
461465
return current_node.operator(eval(current_node.left_child), eval(current_node.right_child))
462466
```
463467
464-
468+
# Arguments
469+
- `tree::Node`: The root node of the tree to evaluate.
470+
- `cX::AbstractArray{T,N}`: The input data to evaluate the tree on.
471+
- `operators::GenericOperatorEnum`: The operators used in the tree.
472+
- `throw_errors::Bool=true`: Whether to throw errors
473+
if they occur during evaluation. Otherwise,
474+
MethodErrors will be caught before they happen and
475+
evaluation will return `nothing`,
476+
rather than throwing an error. This is useful in cases
477+
where you are unsure if a particular tree is valid or not,
478+
and would prefer to work with `nothing` as an output.
465479
466480
# Returns
467-
468481
- `(output, complete)::Tuple{Any, Bool}`: the result,
469482
as well as if the evaluation completed successfully (true/false).
470483
If evaluation failed, `nothing` will be returned for the first argument.
471484
A `false` complete means an operator was called on input types
472485
that it was not defined for.
473486
"""
474487
function eval_tree_array(
475-
tree::Node{T1}, cX::AbstractArray{T2,N}, operators::GenericOperatorEnum
476-
) where {T1,T2,N}
488+
tree::Node, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true
489+
)
490+
!throw_errors && return _eval_tree_array(tree, cX, operators, Val(false))
491+
try
492+
return _eval_tree_array(tree, cX, operators, Val(true))
493+
catch e
494+
tree_s = string_tree(tree, operators)
495+
error_msg = "Failed to evaluate tree $(tree_s)."
496+
if isa(e, MethodError)
497+
error_msg *= (
498+
" Note that you can efficiently skip MethodErrors" *
499+
" beforehand by passing `throw_errors=false` to " *
500+
" `eval_tree_array`."
501+
)
502+
end
503+
throw(ErrorException(error_msg))
504+
end
505+
end
506+
507+
function _eval_tree_array(
508+
tree::Node{T1},
509+
cX::AbstractArray{T2,N},
510+
operators::GenericOperatorEnum,
511+
::Val{throw_errors},
512+
) where {T1,T2,N,throw_errors}
477513
if tree.degree == 0
478514
if tree.constant
479515
return (tree.val::T1), true
@@ -485,27 +521,33 @@ function eval_tree_array(
485521
end
486522
end
487523
elseif tree.degree == 1
488-
return deg1_eval(tree, cX, vals[tree.op], operators)
524+
return deg1_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
489525
else
490-
return deg2_eval(tree, cX, vals[tree.op], operators)
526+
return deg2_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
491527
end
492528
end
493529

494-
function deg1_eval(tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum) where {op_idx}
530+
function deg1_eval(
531+
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
532+
) where {op_idx,throw_errors}
495533
left, complete = eval_tree_array(tree.l, cX, operators)
496-
!complete && return nothing, false
534+
!throw_errors && !complete && return nothing, false
497535
op = operators.unaops[op_idx]
498-
!hasmethod(op, Tuple{typeof(left)}) && return nothing, false
536+
!throw_errors && !hasmethod(op, Tuple{typeof(left)}) && return nothing, false
499537
return op(left), true
500538
end
501539

502-
function deg2_eval(tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum) where {op_idx}
540+
function deg2_eval(
541+
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
542+
) where {op_idx,throw_errors}
503543
left, complete = eval_tree_array(tree.l, cX, operators)
504-
!complete && return nothing, false
544+
!throw_errors && !complete && return nothing, false
505545
right, complete = eval_tree_array(tree.r, cX, operators)
506-
!complete && return nothing, false
546+
!throw_errors && !complete && return nothing, false
507547
op = operators.binops[op_idx]
508-
!hasmethod(op, Tuple{typeof(left),typeof(right)}) && return nothing, false
548+
!throw_errors &&
549+
!hasmethod(op, Tuple{typeof(left),typeof(right)}) &&
550+
return nothing, false
509551
return op(left, right), true
510552
end
511553

src/OperatorEnumConstruction.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,10 @@ function GenericOperatorEnum(;
283283
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
284284
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
285285

286-
function (tree::Node)(X)
287-
out, did_finish = eval_tree_array(tree, X, $operators)
286+
function (tree::Node)(X; throw_errors::Bool=true)
287+
out, did_finish = eval_tree_array(
288+
tree, X, $operators; throw_errors=throw_errors
289+
)
288290
if !did_finish
289291
return nothing
290292
end

test/test_error_handling.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using DynamicExpressions
2+
using Test
3+
4+
# Test that we generate errors:
5+
baseT = Float64
6+
T = Union{baseT,Vector{baseT},Matrix{baseT}}
7+
8+
scalar_add(x::T, y::T) where {T<:Real} = x + y
9+
10+
operators = GenericOperatorEnum(; binary_operators=[scalar_add], extend_user_operators=true)
11+
12+
x1, x2, x3 = [Node(T; feature=i) for i in 1:3]
13+
14+
tree = Node(1, x1, x2)
15+
16+
# With error handling:
17+
try
18+
eval_tree_array(tree, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators; throw_errors=true)
19+
@test false
20+
catch e
21+
@test isa(e, ErrorException)
22+
expected_error_msg = "Failed to evaluate tree"
23+
@test occursin(expected_error_msg, e.msg)
24+
end
25+
26+
# Without error handling:
27+
output, flag = eval_tree_array(
28+
tree, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators; throw_errors=false
29+
)
30+
@test output === nothing
31+
@test !flag
32+
33+
# Default is to catch errors:
34+
try
35+
tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
36+
@test false
37+
catch e
38+
@test isa(e, ErrorException)
39+
end
40+
41+
# But can be overrided:
42+
output = tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; throw_errors=false)

test/unittest.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,7 @@ end
5151
@safetestset "Test tensor operators" begin
5252
include("test_tensor_operators.jl")
5353
end
54+
55+
@safetestset "Test error handling" begin
56+
include("test_error_handling.jl")
57+
end

0 commit comments

Comments
 (0)