|
| 1 | +module DynamicExpressionsBumperExt |
| 2 | + |
| 3 | +using Bumper: @no_escape, @alloc |
| 4 | +using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce |
| 5 | +using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array |
| 6 | + |
| 7 | +import DynamicExpressions.ExtensionInterfaceModule: |
| 8 | + bumper_eval_tree_array, bumper_kern1!, bumper_kern2! |
| 9 | + |
| 10 | +function bumper_eval_tree_array( |
| 11 | + tree::AbstractExpressionNode{T}, |
| 12 | + cX::AbstractMatrix{T}, |
| 13 | + operators::OperatorEnum, |
| 14 | + ::Val{turbo}, |
| 15 | +) where {T,turbo} |
| 16 | + result = similar(cX, axes(cX, 2)) |
| 17 | + n = size(cX, 2) |
| 18 | + all_ok = Ref(false) |
| 19 | + @no_escape begin |
| 20 | + _result_ok = tree_mapreduce( |
| 21 | + # Leaf nodes, we create an allocation and fill |
| 22 | + # it with the value of the leaf: |
| 23 | + leaf_node -> begin |
| 24 | + ar = @alloc(T, n) |
| 25 | + ok = if leaf_node.constant |
| 26 | + v = leaf_node.val::T |
| 27 | + ar .= v |
| 28 | + isfinite(v) |
| 29 | + else |
| 30 | + ar .= view(cX, leaf_node.feature, :) |
| 31 | + true |
| 32 | + end |
| 33 | + ResultOk(ar, ok) |
| 34 | + end, |
| 35 | + # Branch nodes, we simply pass them to the evaluation kernel: |
| 36 | + branch_node -> branch_node, |
| 37 | + # In the evaluation kernel, we combine the branch nodes |
| 38 | + # with the arrays created by the leaf nodes: |
| 39 | + ((args::Vararg{Any,M}) where {M}) -> |
| 40 | + dispatch_kerns!(operators, args..., Val(turbo)), |
| 41 | + tree; |
| 42 | + break_sharing=Val(true), |
| 43 | + ) |
| 44 | + x = _result_ok.x |
| 45 | + result .= x |
| 46 | + all_ok[] = _result_ok.ok |
| 47 | + end |
| 48 | + return (result, all_ok[]) |
| 49 | +end |
| 50 | + |
| 51 | +function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where {turbo} |
| 52 | + cumulator.ok || return cumulator |
| 53 | + |
| 54 | + out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo)) |
| 55 | + return ResultOk(out, !is_bad_array(out)) |
| 56 | +end |
| 57 | +function dispatch_kerns!( |
| 58 | + operators, branch_node, cumulator1, cumulator2, ::Val{turbo} |
| 59 | +) where {turbo} |
| 60 | + cumulator1.ok || return cumulator1 |
| 61 | + cumulator2.ok || return cumulator2 |
| 62 | + |
| 63 | + out = dispatch_kern2!( |
| 64 | + operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo) |
| 65 | + ) |
| 66 | + return ResultOk(out, !is_bad_array(out)) |
| 67 | +end |
| 68 | + |
| 69 | +@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo} |
| 70 | + nuna = counttuple(unaops) |
| 71 | + quote |
| 72 | + Base.@nif( |
| 73 | + $nuna, |
| 74 | + i -> i == op_idx, |
| 75 | + i -> let op = unaops[i] |
| 76 | + return bumper_kern1!(op, cumulator, Val(turbo)) |
| 77 | + end, |
| 78 | + ) |
| 79 | + end |
| 80 | +end |
| 81 | +@generated function dispatch_kern2!( |
| 82 | + binops, op_idx, cumulator1, cumulator2, ::Val{turbo} |
| 83 | +) where {turbo} |
| 84 | + nbin = counttuple(binops) |
| 85 | + quote |
| 86 | + Base.@nif( |
| 87 | + $nbin, |
| 88 | + i -> i == op_idx, |
| 89 | + i -> let op = binops[i] |
| 90 | + return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo)) |
| 91 | + end, |
| 92 | + ) |
| 93 | + end |
| 94 | +end |
| 95 | +function bumper_kern1!(op::F, cumulator, ::Val{false}) where {F} |
| 96 | + @. cumulator = op(cumulator) |
| 97 | + return cumulator |
| 98 | +end |
| 99 | +function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}) where {F} |
| 100 | + @. cumulator1 = op(cumulator1, cumulator2) |
| 101 | + return cumulator1 |
| 102 | +end |
| 103 | + |
| 104 | +end |
0 commit comments