Skip to content

Commit a36ccde

Browse files
francescobriviothesps
authored andcommitted
implement balanced tree reduce for xilinxhls backend
1 parent 271401c commit a36ccde

File tree

3 files changed

+44
-14
lines changed

3 files changed

+44
-14
lines changed

conifer/backends/xilinxhls/firmware/BDT_unrolled.h

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,37 @@
55

66
namespace BDT{
77

8+
/* ---
9+
* Balanced tree reduce implementation.
10+
* Reduces an array of inputs to a single value using the template binary operator 'Op',
11+
* for example summing all elements with OpAdd, or finding the maximum with OpMax
12+
* Use only when the input array is fully unrolled. Or, slice out a fully unrolled section
13+
* before applying and accumulate the result over the rolled dimension.
14+
* Required for emulation to guarantee equality of ordering.
15+
* --- */
16+
constexpr int floorlog2(int x) { return (x < 2) ? 0 : 1 + floorlog2(x / 2); }
17+
18+
constexpr int pow2(int x) { return x == 0 ? 1 : 2 * pow2(x - 1); }
19+
20+
template <class T, int N, class Op> T reduce(const T *x, Op op) {
21+
static constexpr int leftN = pow2(floorlog2(N - 1)) > 0 ? pow2(floorlog2(N - 1)) : 0;
22+
static constexpr int rightN = N - leftN > 0 ? N - leftN : 0;
23+
if (N == 1) {
24+
return x[0];
25+
}
26+
if (N == 2) {
27+
return op(x[0], x[1]);
28+
}
29+
return op(reduce<T, leftN, Op>(x, op), reduce<T, rightN, Op>(x + leftN, op));
30+
}
31+
32+
template <class T> class OpAdd {
33+
public:
34+
T operator()(T a, T b) { return a + b; }
35+
};
36+
37+
// Number of trees given number of classes
838
constexpr int fn_classes(int n_classes){
9-
// Number of trees given number of classes
1039
return n_classes == 2 ? 1 : n_classes;
1140
}
1241

@@ -99,23 +128,24 @@ struct BDT{
99128
public:
100129
score_t normalisation;
101130
score_t init_predict[fn_classes(n_classes)];
131+
OpAdd<score_t> op_add;
102132

103-
void tree_scores(input_t x, score_t scores[n_trees][fn_classes(n_classes)]) const;
133+
void tree_scores(input_t x, score_t scores[fn_classes(n_classes)][n_trees]) const;
104134

105135
void decision_function(input_t x, score_t score[fn_classes(n_classes)]) const{
106-
score_t scores[n_trees][fn_classes(n_classes)];
136+
score_t scores[fn_classes(n_classes)][n_trees];
107137
#pragma HLS ARRAY_PARTITION variable=scores dim=0
138+
// Get predictions scores
139+
tree_scores(x, scores);
140+
// Reduce
141+
Reduce:
108142
for(int j = 0; j < fn_classes(n_classes); j++){
143+
// Init predictions
109144
score[j] = init_predict[j];
145+
// Sum predictions from trees via "reduce" method
146+
score[j] += reduce<score_t, n_trees, OpAdd<score_t>>(scores[j], op_add);
110147
}
111-
tree_scores(x, scores);
112-
Trees:
113-
for(int i = 0; i < n_trees; i++){
114-
Classes:
115-
for(int j = 0; j < fn_classes(n_classes); j++){
116-
score[j] += scores[i][j];
117-
}
118-
}
148+
// Normalize predictions
119149
for(int j = 0; j < fn_classes(n_classes); j++){
120150
score[j] *= normalisation;
121151
}

conifer/backends/xilinxhls/hls-template/firmware/BDT_unrolled.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include "parameters.h"
33

44
template<>
5-
void BDT::BDT<n_trees, n_classes, input_arr_t, score_t, threshold_t>::tree_scores(input_arr_t x, score_t scores[n_trees][fn_classes(n_classes)]) const {
5+
void BDT::BDT<n_trees, n_classes, input_arr_t, score_t, threshold_t>::tree_scores(input_arr_t x, score_t scores[fn_classes(n_classes)][n_trees]) const {
66
// conifer insert tree_scores
77
}
88

conifer/backends/xilinxhls/writer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def write_bdt_h(self):
139139
newline = ''
140140
for it, trees in enumerate(self.trees):
141141
for ic, tree in enumerate(trees):
142-
newline += f' scores[{it}][{ic}] = tree_{it}_{ic}.decision_function(x);\n'
142+
newline += f' scores[{ic}][{it}] = tree_{ic}_{it}.decision_function(x);\n'
143143
else:
144144
newline = line
145145
fout.write(newline)
@@ -227,7 +227,7 @@ def _write_parameters_h_unrolled(self, fout):
227227
for iclass, tree in enumerate(trees):
228228
fout.write(f'static const BDT::Tree<{itree*nc+iclass}, {tree.n_nodes()}, {tree.n_leaves()}')
229229
fout.write(f', input_arr_t, score_t, threshold_t>')
230-
fout.write(f' tree_{itree}_{iclass} = {{\n')
230+
fout.write(f' tree_{iclass}_{itree} = {{\n')
231231
# loop over fields
232232
for ifield, field in enumerate(tree_fields):
233233
newline = ' {'

0 commit comments

Comments
 (0)