|
| 1 | +#ifndef SQL_ENGINE_OPERATORS_MERGE_AGGREGATE_OP_H |
| 2 | +#define SQL_ENGINE_OPERATORS_MERGE_AGGREGATE_OP_H |
| 3 | + |
| 4 | +#include "sql_engine/operator.h" |
| 5 | +#include "sql_engine/value.h" |
| 6 | +#include "sql_engine/row.h" |
| 7 | +#include "sql_parser/arena.h" |
| 8 | +#include <vector> |
| 9 | +#include <unordered_map> |
| 10 | +#include <string> |
| 11 | +#include <cstring> |
| 12 | + |
| 13 | +namespace sql_engine { |
| 14 | + |
| 15 | +// Merge operation types for distributed aggregation |
| 16 | +// These define how partial aggregates from multiple shards are combined. |
| 17 | +enum class MergeOp : uint8_t { |
| 18 | + SUM_OF_COUNTS = 0, // COUNT(*) or COUNT(col): sum the partial counts |
| 19 | + SUM_OF_SUMS = 1, // SUM(col): sum the partial sums |
| 20 | + MIN_OF_MINS = 2, // MIN(col): min of partial mins |
| 21 | + MAX_OF_MAXES = 3, // MAX(col): max of partial maxes |
| 22 | + AVG_SUM = 4, // AVG decomposed: this column holds the partial SUM |
| 23 | + AVG_COUNT = 5, // AVG decomposed: this column holds the partial COUNT |
| 24 | +}; |
| 25 | + |
| 26 | +class MergeAggregateOperator : public Operator { |
| 27 | +public: |
| 28 | + MergeAggregateOperator(std::vector<Operator*> children, |
| 29 | + uint16_t group_key_count, |
| 30 | + const uint8_t* merge_ops, |
| 31 | + uint16_t merge_op_count, |
| 32 | + sql_parser::Arena& arena) |
| 33 | + : children_(std::move(children)), |
| 34 | + group_key_count_(group_key_count), |
| 35 | + merge_op_count_(merge_op_count), |
| 36 | + arena_(arena) |
| 37 | + { |
| 38 | + merge_ops_.assign(merge_ops, merge_ops + merge_op_count); |
| 39 | + } |
| 40 | + |
| 41 | + void open() override { |
| 42 | + groups_.clear(); |
| 43 | + group_order_.clear(); |
| 44 | + result_idx_ = 0; |
| 45 | + |
| 46 | + // Open and consume all children |
| 47 | + for (auto* child : children_) { |
| 48 | + child->open(); |
| 49 | + Row row{}; |
| 50 | + while (child->next(row)) { |
| 51 | + std::string key = compute_group_key(row); |
| 52 | + |
| 53 | + auto it = groups_.find(key); |
| 54 | + if (it == groups_.end()) { |
| 55 | + GroupState state; |
| 56 | + // Copy group key values |
| 57 | + for (uint16_t i = 0; i < group_key_count_; ++i) { |
| 58 | + state.group_values.push_back(row.get(i)); |
| 59 | + } |
| 60 | + // Initialize aggregate accumulators |
| 61 | + state.agg_values.reserve(merge_op_count_); |
| 62 | + for (uint16_t i = 0; i < merge_op_count_; ++i) { |
| 63 | + state.agg_values.push_back(value_null()); |
| 64 | + state.agg_counts[i] = 0; |
| 65 | + state.agg_has_value[i] = false; |
| 66 | + } |
| 67 | + groups_[key] = std::move(state); |
| 68 | + group_order_.push_back(key); |
| 69 | + it = groups_.find(key); |
| 70 | + } |
| 71 | + |
| 72 | + // Merge partial aggregates |
| 73 | + merge_row(it->second, row); |
| 74 | + } |
| 75 | + child->close(); |
| 76 | + } |
| 77 | + } |
| 78 | + |
| 79 | + bool next(Row& out) override { |
| 80 | + if (result_idx_ >= group_order_.size()) return false; |
| 81 | + |
| 82 | + const auto& key = group_order_[result_idx_++]; |
| 83 | + const auto& state = groups_[key]; |
| 84 | + |
| 85 | + // Compute output column count: group keys + final agg columns |
| 86 | + // We need to count actual output columns (AVG_SUM + AVG_COUNT -> 1 output) |
| 87 | + uint16_t output_agg_count = 0; |
| 88 | + for (uint16_t i = 0; i < merge_op_count_; ++i) { |
| 89 | + if (merge_ops_[i] != static_cast<uint8_t>(MergeOp::AVG_COUNT)) { |
| 90 | + output_agg_count++; |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + uint16_t cols = group_key_count_ + output_agg_count; |
| 95 | + out = make_row(arena_, cols); |
| 96 | + |
| 97 | + for (uint16_t i = 0; i < group_key_count_; ++i) { |
| 98 | + out.set(i, state.group_values[i]); |
| 99 | + } |
| 100 | + |
| 101 | + uint16_t out_idx = group_key_count_; |
| 102 | + for (uint16_t i = 0; i < merge_op_count_; ++i) { |
| 103 | + MergeOp op = static_cast<MergeOp>(merge_ops_[i]); |
| 104 | + if (op == MergeOp::AVG_COUNT) continue; // consumed by AVG_SUM |
| 105 | + |
| 106 | + if (op == MergeOp::AVG_SUM) { |
| 107 | + // Find the corresponding AVG_COUNT (next column) |
| 108 | + double sum = state.agg_values[i].is_null() ? 0.0 : state.agg_values[i].to_double(); |
| 109 | + int64_t count = 0; |
| 110 | + // Look for the AVG_COUNT that follows |
| 111 | + if (i + 1 < merge_op_count_ && |
| 112 | + merge_ops_[i + 1] == static_cast<uint8_t>(MergeOp::AVG_COUNT)) { |
| 113 | + count = state.agg_values[i + 1].is_null() ? 0 : state.agg_values[i + 1].to_int64(); |
| 114 | + } |
| 115 | + if (count > 0) { |
| 116 | + out.set(out_idx++, value_double(sum / static_cast<double>(count))); |
| 117 | + } else { |
| 118 | + out.set(out_idx++, value_null()); |
| 119 | + } |
| 120 | + } else { |
| 121 | + out.set(out_idx++, state.agg_values[i]); |
| 122 | + } |
| 123 | + } |
| 124 | + return true; |
| 125 | + } |
| 126 | + |
| 127 | + void close() override { |
| 128 | + groups_.clear(); |
| 129 | + group_order_.clear(); |
| 130 | + } |
| 131 | + |
| 132 | +private: |
| 133 | + std::vector<Operator*> children_; |
| 134 | + uint16_t group_key_count_; |
| 135 | + std::vector<uint8_t> merge_ops_; |
| 136 | + uint16_t merge_op_count_; |
| 137 | + sql_parser::Arena& arena_; |
| 138 | + |
| 139 | + struct GroupState { |
| 140 | + std::vector<Value> group_values; |
| 141 | + std::vector<Value> agg_values; |
| 142 | + std::unordered_map<uint16_t, int64_t> agg_counts; |
| 143 | + std::unordered_map<uint16_t, bool> agg_has_value; |
| 144 | + }; |
| 145 | + |
| 146 | + std::unordered_map<std::string, GroupState> groups_; |
| 147 | + std::vector<std::string> group_order_; |
| 148 | + size_t result_idx_ = 0; |
| 149 | + |
| 150 | + std::string compute_group_key(const Row& row) { |
| 151 | + std::string key; |
| 152 | + for (uint16_t i = 0; i < group_key_count_; ++i) { |
| 153 | + const Value& v = row.get(i); |
| 154 | + append_value_to_key(key, v); |
| 155 | + key += '\x01'; |
| 156 | + } |
| 157 | + return key; |
| 158 | + } |
| 159 | + |
| 160 | + static void append_value_to_key(std::string& key, const Value& v) { |
| 161 | + if (v.is_null()) { key += "N"; return; } |
| 162 | + switch (v.tag) { |
| 163 | + case Value::TAG_BOOL: key += v.bool_val ? "T" : "F"; break; |
| 164 | + case Value::TAG_INT64: key += std::to_string(v.int_val); break; |
| 165 | + case Value::TAG_UINT64: key += std::to_string(v.uint_val); break; |
| 166 | + case Value::TAG_DOUBLE: key += std::to_string(v.double_val); break; |
| 167 | + case Value::TAG_STRING: |
| 168 | + key.append(v.str_val.ptr, v.str_val.len); |
| 169 | + break; |
| 170 | + default: key += "?"; break; |
| 171 | + } |
| 172 | + } |
| 173 | + |
| 174 | + void merge_row(GroupState& state, const Row& row) { |
| 175 | + for (uint16_t i = 0; i < merge_op_count_; ++i) { |
| 176 | + uint16_t col_idx = group_key_count_ + i; |
| 177 | + if (col_idx >= row.column_count) continue; |
| 178 | + Value v = row.get(col_idx); |
| 179 | + |
| 180 | + MergeOp op = static_cast<MergeOp>(merge_ops_[i]); |
| 181 | + switch (op) { |
| 182 | + case MergeOp::SUM_OF_COUNTS: |
| 183 | + case MergeOp::SUM_OF_SUMS: |
| 184 | + case MergeOp::AVG_SUM: |
| 185 | + case MergeOp::AVG_COUNT: { |
| 186 | + if (v.is_null()) break; |
| 187 | + if (!state.agg_has_value[i]) { |
| 188 | + state.agg_values[i] = value_double(v.to_double()); |
| 189 | + state.agg_has_value[i] = true; |
| 190 | + } else { |
| 191 | + double cur = state.agg_values[i].to_double(); |
| 192 | + state.agg_values[i] = value_double(cur + v.to_double()); |
| 193 | + } |
| 194 | + break; |
| 195 | + } |
| 196 | + case MergeOp::MIN_OF_MINS: { |
| 197 | + if (v.is_null()) break; |
| 198 | + if (!state.agg_has_value[i] || compare_values(v, state.agg_values[i]) < 0) { |
| 199 | + state.agg_values[i] = v; |
| 200 | + state.agg_has_value[i] = true; |
| 201 | + } |
| 202 | + break; |
| 203 | + } |
| 204 | + case MergeOp::MAX_OF_MAXES: { |
| 205 | + if (v.is_null()) break; |
| 206 | + if (!state.agg_has_value[i] || compare_values(v, state.agg_values[i]) > 0) { |
| 207 | + state.agg_values[i] = v; |
| 208 | + state.agg_has_value[i] = true; |
| 209 | + } |
| 210 | + break; |
| 211 | + } |
| 212 | + } |
| 213 | + } |
| 214 | + } |
| 215 | + |
| 216 | + static int compare_values(const Value& a, const Value& b) { |
| 217 | + if (a.is_null() && b.is_null()) return 0; |
| 218 | + if (a.is_null()) return -1; |
| 219 | + if (b.is_null()) return 1; |
| 220 | + if (a.is_numeric() && b.is_numeric()) { |
| 221 | + double da = a.to_double(), db = b.to_double(); |
| 222 | + return da < db ? -1 : (da > db ? 1 : 0); |
| 223 | + } |
| 224 | + if (a.tag == Value::TAG_STRING && b.tag == Value::TAG_STRING) { |
| 225 | + uint32_t minlen = a.str_val.len < b.str_val.len ? a.str_val.len : b.str_val.len; |
| 226 | + int cmp = std::memcmp(a.str_val.ptr, b.str_val.ptr, minlen); |
| 227 | + if (cmp != 0) return cmp; |
| 228 | + return a.str_val.len < b.str_val.len ? -1 : (a.str_val.len > b.str_val.len ? 1 : 0); |
| 229 | + } |
| 230 | + return 0; |
| 231 | + } |
| 232 | +}; |
| 233 | + |
| 234 | +} // namespace sql_engine |
| 235 | + |
| 236 | +#endif // SQL_ENGINE_OPERATORS_MERGE_AGGREGATE_OP_H |
0 commit comments