Skip to content

Commit 6638ddc

Browse files
committed
feat: add distributed query planner — decompose queries across backends with merge operators
1 parent eba6985 commit 6638ddc

File tree

11 files changed

+2272
-1
lines changed

11 files changed

+2272
-1
lines changed

Makefile.new

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ TEST_SRCS = $(TEST_DIR)/test_main.cpp \
5656
$(TEST_DIR)/test_plan_builder.cpp \
5757
$(TEST_DIR)/test_operators.cpp \
5858
$(TEST_DIR)/test_plan_executor.cpp \
59-
$(TEST_DIR)/test_optimizer.cpp
59+
$(TEST_DIR)/test_optimizer.cpp \
60+
$(TEST_DIR)/test_distributed_planner.cpp
6061
TEST_OBJS = $(TEST_SRCS:.cpp=.o)
6162
TEST_TARGET = $(PROJECT_ROOT)/run_tests
6263

include/sql_engine/distributed_planner.h

Lines changed: 724 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
#ifndef SQL_ENGINE_OPERATORS_MERGE_AGGREGATE_OP_H
2+
#define SQL_ENGINE_OPERATORS_MERGE_AGGREGATE_OP_H
3+
4+
#include "sql_engine/operator.h"
5+
#include "sql_engine/value.h"
6+
#include "sql_engine/row.h"
7+
#include "sql_parser/arena.h"
8+
#include <vector>
9+
#include <unordered_map>
10+
#include <string>
11+
#include <cstring>
12+
13+
namespace sql_engine {
14+
15+
/// Describes how one partial-aggregate column, as returned by each shard,
/// is folded into the final value during distributed aggregation.
///
/// The numeric values are spelled out explicitly because merge ops are
/// handed around as raw `uint8_t` codes.
enum class MergeOp : uint8_t {
    /// COUNT(*) or COUNT(col): add up the partial counts.
    SUM_OF_COUNTS = 0,

    /// SUM(col): add up the partial sums.
    SUM_OF_SUMS = 1,

    /// MIN(col): keep the smallest of the partial minimums.
    MIN_OF_MINS = 2,

    /// MAX(col): keep the largest of the partial maximums.
    MAX_OF_MAXES = 3,

    /// AVG decomposed, first half: this column carries the partial SUM.
    AVG_SUM = 4,

    /// AVG decomposed, second half: this column carries the partial COUNT.
    AVG_COUNT = 5,
};
25+
26+
class MergeAggregateOperator : public Operator {
27+
public:
28+
MergeAggregateOperator(std::vector<Operator*> children,
29+
uint16_t group_key_count,
30+
const uint8_t* merge_ops,
31+
uint16_t merge_op_count,
32+
sql_parser::Arena& arena)
33+
: children_(std::move(children)),
34+
group_key_count_(group_key_count),
35+
merge_op_count_(merge_op_count),
36+
arena_(arena)
37+
{
38+
merge_ops_.assign(merge_ops, merge_ops + merge_op_count);
39+
}
40+
41+
void open() override {
42+
groups_.clear();
43+
group_order_.clear();
44+
result_idx_ = 0;
45+
46+
// Open and consume all children
47+
for (auto* child : children_) {
48+
child->open();
49+
Row row{};
50+
while (child->next(row)) {
51+
std::string key = compute_group_key(row);
52+
53+
auto it = groups_.find(key);
54+
if (it == groups_.end()) {
55+
GroupState state;
56+
// Copy group key values
57+
for (uint16_t i = 0; i < group_key_count_; ++i) {
58+
state.group_values.push_back(row.get(i));
59+
}
60+
// Initialize aggregate accumulators
61+
state.agg_values.reserve(merge_op_count_);
62+
for (uint16_t i = 0; i < merge_op_count_; ++i) {
63+
state.agg_values.push_back(value_null());
64+
state.agg_counts[i] = 0;
65+
state.agg_has_value[i] = false;
66+
}
67+
groups_[key] = std::move(state);
68+
group_order_.push_back(key);
69+
it = groups_.find(key);
70+
}
71+
72+
// Merge partial aggregates
73+
merge_row(it->second, row);
74+
}
75+
child->close();
76+
}
77+
}
78+
79+
bool next(Row& out) override {
80+
if (result_idx_ >= group_order_.size()) return false;
81+
82+
const auto& key = group_order_[result_idx_++];
83+
const auto& state = groups_[key];
84+
85+
// Compute output column count: group keys + final agg columns
86+
// We need to count actual output columns (AVG_SUM + AVG_COUNT -> 1 output)
87+
uint16_t output_agg_count = 0;
88+
for (uint16_t i = 0; i < merge_op_count_; ++i) {
89+
if (merge_ops_[i] != static_cast<uint8_t>(MergeOp::AVG_COUNT)) {
90+
output_agg_count++;
91+
}
92+
}
93+
94+
uint16_t cols = group_key_count_ + output_agg_count;
95+
out = make_row(arena_, cols);
96+
97+
for (uint16_t i = 0; i < group_key_count_; ++i) {
98+
out.set(i, state.group_values[i]);
99+
}
100+
101+
uint16_t out_idx = group_key_count_;
102+
for (uint16_t i = 0; i < merge_op_count_; ++i) {
103+
MergeOp op = static_cast<MergeOp>(merge_ops_[i]);
104+
if (op == MergeOp::AVG_COUNT) continue; // consumed by AVG_SUM
105+
106+
if (op == MergeOp::AVG_SUM) {
107+
// Find the corresponding AVG_COUNT (next column)
108+
double sum = state.agg_values[i].is_null() ? 0.0 : state.agg_values[i].to_double();
109+
int64_t count = 0;
110+
// Look for the AVG_COUNT that follows
111+
if (i + 1 < merge_op_count_ &&
112+
merge_ops_[i + 1] == static_cast<uint8_t>(MergeOp::AVG_COUNT)) {
113+
count = state.agg_values[i + 1].is_null() ? 0 : state.agg_values[i + 1].to_int64();
114+
}
115+
if (count > 0) {
116+
out.set(out_idx++, value_double(sum / static_cast<double>(count)));
117+
} else {
118+
out.set(out_idx++, value_null());
119+
}
120+
} else {
121+
out.set(out_idx++, state.agg_values[i]);
122+
}
123+
}
124+
return true;
125+
}
126+
127+
void close() override {
128+
groups_.clear();
129+
group_order_.clear();
130+
}
131+
132+
private:
133+
std::vector<Operator*> children_;
134+
uint16_t group_key_count_;
135+
std::vector<uint8_t> merge_ops_;
136+
uint16_t merge_op_count_;
137+
sql_parser::Arena& arena_;
138+
139+
struct GroupState {
140+
std::vector<Value> group_values;
141+
std::vector<Value> agg_values;
142+
std::unordered_map<uint16_t, int64_t> agg_counts;
143+
std::unordered_map<uint16_t, bool> agg_has_value;
144+
};
145+
146+
std::unordered_map<std::string, GroupState> groups_;
147+
std::vector<std::string> group_order_;
148+
size_t result_idx_ = 0;
149+
150+
std::string compute_group_key(const Row& row) {
151+
std::string key;
152+
for (uint16_t i = 0; i < group_key_count_; ++i) {
153+
const Value& v = row.get(i);
154+
append_value_to_key(key, v);
155+
key += '\x01';
156+
}
157+
return key;
158+
}
159+
160+
static void append_value_to_key(std::string& key, const Value& v) {
161+
if (v.is_null()) { key += "N"; return; }
162+
switch (v.tag) {
163+
case Value::TAG_BOOL: key += v.bool_val ? "T" : "F"; break;
164+
case Value::TAG_INT64: key += std::to_string(v.int_val); break;
165+
case Value::TAG_UINT64: key += std::to_string(v.uint_val); break;
166+
case Value::TAG_DOUBLE: key += std::to_string(v.double_val); break;
167+
case Value::TAG_STRING:
168+
key.append(v.str_val.ptr, v.str_val.len);
169+
break;
170+
default: key += "?"; break;
171+
}
172+
}
173+
174+
void merge_row(GroupState& state, const Row& row) {
175+
for (uint16_t i = 0; i < merge_op_count_; ++i) {
176+
uint16_t col_idx = group_key_count_ + i;
177+
if (col_idx >= row.column_count) continue;
178+
Value v = row.get(col_idx);
179+
180+
MergeOp op = static_cast<MergeOp>(merge_ops_[i]);
181+
switch (op) {
182+
case MergeOp::SUM_OF_COUNTS:
183+
case MergeOp::SUM_OF_SUMS:
184+
case MergeOp::AVG_SUM:
185+
case MergeOp::AVG_COUNT: {
186+
if (v.is_null()) break;
187+
if (!state.agg_has_value[i]) {
188+
state.agg_values[i] = value_double(v.to_double());
189+
state.agg_has_value[i] = true;
190+
} else {
191+
double cur = state.agg_values[i].to_double();
192+
state.agg_values[i] = value_double(cur + v.to_double());
193+
}
194+
break;
195+
}
196+
case MergeOp::MIN_OF_MINS: {
197+
if (v.is_null()) break;
198+
if (!state.agg_has_value[i] || compare_values(v, state.agg_values[i]) < 0) {
199+
state.agg_values[i] = v;
200+
state.agg_has_value[i] = true;
201+
}
202+
break;
203+
}
204+
case MergeOp::MAX_OF_MAXES: {
205+
if (v.is_null()) break;
206+
if (!state.agg_has_value[i] || compare_values(v, state.agg_values[i]) > 0) {
207+
state.agg_values[i] = v;
208+
state.agg_has_value[i] = true;
209+
}
210+
break;
211+
}
212+
}
213+
}
214+
}
215+
216+
static int compare_values(const Value& a, const Value& b) {
217+
if (a.is_null() && b.is_null()) return 0;
218+
if (a.is_null()) return -1;
219+
if (b.is_null()) return 1;
220+
if (a.is_numeric() && b.is_numeric()) {
221+
double da = a.to_double(), db = b.to_double();
222+
return da < db ? -1 : (da > db ? 1 : 0);
223+
}
224+
if (a.tag == Value::TAG_STRING && b.tag == Value::TAG_STRING) {
225+
uint32_t minlen = a.str_val.len < b.str_val.len ? a.str_val.len : b.str_val.len;
226+
int cmp = std::memcmp(a.str_val.ptr, b.str_val.ptr, minlen);
227+
if (cmp != 0) return cmp;
228+
return a.str_val.len < b.str_val.len ? -1 : (a.str_val.len > b.str_val.len ? 1 : 0);
229+
}
230+
return 0;
231+
}
232+
};
233+
234+
} // namespace sql_engine
235+
236+
#endif // SQL_ENGINE_OPERATORS_MERGE_AGGREGATE_OP_H
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#ifndef SQL_ENGINE_OPERATORS_MERGE_SORT_OP_H
2+
#define SQL_ENGINE_OPERATORS_MERGE_SORT_OP_H
3+
4+
#include "sql_engine/operator.h"
5+
#include "sql_engine/value.h"
6+
#include "sql_engine/row.h"
7+
#include <vector>
8+
#include <queue>
9+
#include <functional>
10+
#include <cstring>
11+
12+
namespace sql_engine {
13+
14+
// N-way merge sort operator.
15+
// Takes N child operators, each returning pre-sorted rows, and performs
16+
// an N-way merge using a min-heap to produce globally sorted output.
17+
class MergeSortOperator : public Operator {
18+
public:
19+
MergeSortOperator(std::vector<Operator*> children,
20+
const uint16_t* sort_col_indices,
21+
const uint8_t* directions,
22+
uint16_t key_count)
23+
: children_(std::move(children)), key_count_(key_count)
24+
{
25+
sort_cols_.assign(sort_col_indices, sort_col_indices + key_count);
26+
directions_.assign(directions, directions + key_count);
27+
}
28+
29+
void open() override {
30+
// Open all children and get first row from each
31+
heads_.resize(children_.size());
32+
has_row_.assign(children_.size(), false);
33+
34+
for (size_t i = 0; i < children_.size(); ++i) {
35+
children_[i]->open();
36+
Row row{};
37+
if (children_[i]->next(row)) {
38+
heads_[i] = row;
39+
has_row_[i] = true;
40+
}
41+
}
42+
43+
// Build min-heap using std::function comparator
44+
// Comparator: returns true if a should come AFTER b in the output
45+
// (priority_queue is max-heap, so we invert)
46+
std::function<bool(size_t, size_t)> cmp = [this](size_t a, size_t b) -> bool {
47+
return compare_rows(heads_[a], heads_[b]) > 0;
48+
};
49+
50+
heap_ = HeapType(cmp);
51+
for (size_t i = 0; i < children_.size(); ++i) {
52+
if (has_row_[i]) {
53+
heap_.push(i);
54+
}
55+
}
56+
}
57+
58+
bool next(Row& out) override {
59+
if (heap_.empty()) return false;
60+
61+
size_t idx = heap_.top();
62+
heap_.pop();
63+
64+
out = heads_[idx];
65+
66+
// Advance that child
67+
Row row{};
68+
if (children_[idx]->next(row)) {
69+
heads_[idx] = row;
70+
heap_.push(idx);
71+
} else {
72+
has_row_[idx] = false;
73+
}
74+
75+
return true;
76+
}
77+
78+
void close() override {
79+
for (auto* child : children_) {
80+
child->close();
81+
}
82+
// Clear the heap by assigning an empty one
83+
while (!heap_.empty()) heap_.pop();
84+
}
85+
86+
private:
87+
std::vector<Operator*> children_;
88+
std::vector<uint16_t> sort_cols_;
89+
std::vector<uint8_t> directions_;
90+
uint16_t key_count_;
91+
92+
std::vector<Row> heads_;
93+
std::vector<bool> has_row_;
94+
95+
using HeapType = std::priority_queue<size_t, std::vector<size_t>,
96+
std::function<bool(size_t, size_t)>>;
97+
HeapType heap_{[](size_t, size_t) { return false; }};
98+
99+
// Compare two rows according to sort keys.
100+
// Returns <0 if a < b, 0 if equal, >0 if a > b.
101+
int compare_rows(const Row& a, const Row& b) const {
102+
for (uint16_t k = 0; k < key_count_; ++k) {
103+
uint16_t col = sort_cols_[k];
104+
Value va = (col < a.column_count) ? a.get(col) : value_null();
105+
Value vb = (col < b.column_count) ? b.get(col) : value_null();
106+
int cmp = compare_values(va, vb);
107+
if (cmp == 0) continue;
108+
bool asc = (directions_[k] == 0);
109+
return asc ? cmp : -cmp;
110+
}
111+
return 0;
112+
}
113+
114+
static int compare_values(const Value& a, const Value& b) {
115+
if (a.is_null() && b.is_null()) return 0;
116+
if (a.is_null()) return -1;
117+
if (b.is_null()) return 1;
118+
119+
if (a.is_numeric() && b.is_numeric()) {
120+
double da = a.to_double(), db = b.to_double();
121+
return da < db ? -1 : (da > db ? 1 : 0);
122+
}
123+
124+
if (a.tag == Value::TAG_STRING && b.tag == Value::TAG_STRING) {
125+
uint32_t minlen = a.str_val.len < b.str_val.len ? a.str_val.len : b.str_val.len;
126+
int cmp = std::memcmp(a.str_val.ptr, b.str_val.ptr, minlen);
127+
if (cmp != 0) return cmp;
128+
return a.str_val.len < b.str_val.len ? -1 : (a.str_val.len > b.str_val.len ? 1 : 0);
129+
}
130+
131+
return 0;
132+
}
133+
};
134+
135+
} // namespace sql_engine
136+
137+
#endif // SQL_ENGINE_OPERATORS_MERGE_SORT_OP_H

0 commit comments

Comments
 (0)