@@ -5,8 +5,7 @@ using DynamicExpressions:
5
5
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions
6
6
using DynamicExpressions. UtilsModule: ResultOk, counttuple
7
7
8
- import DynamicExpressions. ExtensionInterfaceModule:
9
- bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
8
+ import DynamicExpressions. ExtensionInterfaceModule: bumper_eval_tree_array, bumper_kern!
10
9
11
10
function bumper_eval_tree_array (
12
11
tree:: AbstractExpressionNode{T} ,
@@ -37,8 +36,7 @@ function bumper_eval_tree_array(
37
36
branch_node -> branch_node,
38
37
# In the evaluation kernel, we combine the branch nodes
39
38
# with the arrays created by the leaf nodes:
40
- ((args:: Vararg{Any,M} ) where {M}) ->
41
- dispatch_kerns! (operators, args... , eval_options),
39
+ KernelDispatcher (operators, eval_options),
42
40
tree;
43
41
break_sharing= Val (true ),
44
42
)
@@ -49,63 +47,44 @@ function bumper_eval_tree_array(
49
47
return (result, all_ok[])
50
48
end
51
49
52
- function dispatch_kerns! (
53
- operators, branch_node, cumulator, eval_options:: EvalOptions{<:Any,true,early_exit}
54
- ) where {early_exit}
55
- cumulator. ok || return cumulator
56
-
57
- out = dispatch_kern1! (operators. unaops, branch_node. op, cumulator. x, eval_options)
58
- return ResultOk (out, early_exit ? is_valid_array (out) : true )
59
- end
60
- function dispatch_kerns! (
61
- operators,
62
- branch_node,
63
- cumulator1,
64
- cumulator2,
65
- eval_options:: EvalOptions{<:Any,true,early_exit} ,
66
- ) where {early_exit}
67
- cumulator1. ok || return cumulator1
68
- cumulator2. ok || return cumulator2
69
-
70
- out = dispatch_kern2! (
71
- operators. binops, branch_node. op, cumulator1. x, cumulator2. x, eval_options
72
- )
73
- return ResultOk (out, early_exit ? is_valid_array (out) : true )
50
+ struct KernelDispatcher{O<: OperatorEnum ,E<: EvalOptions{<:Any,true,<:Any} } <: Function
51
+ operators:: O
52
+ eval_options:: E
74
53
end
75
54
76
- @generated function dispatch_kern1! (unaops, op_idx, cumulator, eval_options:: EvalOptions )
77
- nuna = counttuple (unaops)
55
+ @generated function (kd:: KernelDispatcher{<:Any,<:EvalOptions{<:Any,true,early_exit}} )(
56
+ branch_node, inputs:: Vararg{Any,degree}
57
+ ) where {degree,early_exit}
78
58
quote
79
- Base. @nif (
80
- $ nuna,
81
- i -> i == op_idx,
82
- i -> let op = unaops[i]
83
- return bumper_kern1! (op, cumulator, eval_options)
84
- end ,
85
- )
59
+ Base. Cartesian. @nexprs ($ degree, i -> inputs[i]. ok || return inputs[i])
60
+ cumulators = Base. Cartesian. @ntuple ($ degree, i -> inputs[i]. x)
61
+ out = dispatch_kerns! (kd. operators, branch_node, cumulators, kd. eval_options)
62
+ return ResultOk (out, early_exit ? is_valid_array (out) : true )
86
63
end
87
64
end
88
- @generated function dispatch_kern2! (
89
- binops, op_idx, cumulator1, cumulator2, eval_options:: EvalOptions
90
- )
91
- nbin = counttuple (binops)
65
+ @generated function dispatch_kerns! (
66
+ operators:: OperatorEnum{OPS} ,
67
+ branch_node,
68
+ cumulators:: Tuple{Vararg{Any,degree}} ,
69
+ eval_options:: EvalOptions ,
70
+ ) where {OPS,degree}
71
+ nops = length (OPS. types[degree]. types)
92
72
quote
93
- Base. @nif (
94
- $ nbin,
73
+ op_idx = branch_node. op
74
+ Base. Cartesian. @nif (
75
+ $ nops,
95
76
i -> i == op_idx,
96
- i -> let op = binops[i]
97
- return bumper_kern2! (op, cumulator1, cumulator2, eval_options)
98
- end ,
77
+ i -> bumper_kern! (operators[$ degree][i], cumulators, eval_options)
99
78
)
100
79
end
101
80
end
102
- function bumper_kern1! (op :: F , cumulator, :: EvalOptions{false,true} ) where {F}
103
- @. cumulator = op (cumulator)
104
- return cumulator
105
- end
106
- function bumper_kern2! (op :: F , cumulator1, cumulator2, :: EvalOptions{false,true} ) where {F}
107
- @. cumulator1 = op (cumulator1, cumulator2 )
108
- return cumulator1
81
+
82
+ function bumper_kern! (
83
+ op :: F , cumulators :: Tuple{Vararg{Any,degree}} , :: EvalOptions{false,true,early_exit}
84
+ ) where {F,degree,early_exit}
85
+ cumulator_1 = first (cumulators)
86
+ @. cumulator_1 = op (cumulators ... )
87
+ return cumulator_1
109
88
end
110
89
111
90
end
0 commit comments