diff --git a/velox/experimental/cudf/CudfConfig.h b/velox/experimental/cudf/CudfConfig.h index 92657d8fc7e..10e98e64fcc 100644 --- a/velox/experimental/cudf/CudfConfig.h +++ b/velox/experimental/cudf/CudfConfig.h @@ -35,6 +35,8 @@ struct CudfConfig { "cudf.ast_expression_priority"}; static constexpr const char* kCudfAllowCpuFallback{"cudf.allow_cpu_fallback"}; static constexpr const char* kCudfLogFallback{"cudf.log_fallback"}; + static constexpr const char* kCudfSortMergeJoinCardinalityThreshold{ + "cudf.sort_merge_join.cardinality_threshold"}; /// Singleton CudfConfig instance. /// Clients must set the configs below before invoking registerCudf(). @@ -78,6 +80,11 @@ struct CudfConfig { /// Whether to log a reason for falling back to Velox CPU execution. bool logFallback{true}; + + /// Cardinality threshold for sort-merge join. When the ratio of distinct keys + /// to total rows in the build table is below this threshold, sort-merge join + /// is used instead of hash join for inner and left joins. + double sortMergeJoinCardinalityThreshold{0.1}; }; } // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/CudfHashJoin.cpp b/velox/experimental/cudf/exec/CudfHashJoin.cpp index 6de88b23348..991673b46e3 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.cpp +++ b/velox/experimental/cudf/exec/CudfHashJoin.cpp @@ -103,6 +103,50 @@ std::optional CudfHashJoinBridge::getBuildStream() { return buildStream_; } +void CudfHashJoinBridge::setJoinData( + std::optional joinData) { + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "Calling CudfHashJoinBridge::setJoinData" << std::endl; + } + std::vector promises; + { + std::lock_guard l(mutex_); + VELOX_CHECK( + !joinData_.has_value(), "CudfHashJoinBridge already has join data"); + joinData_ = std::move(joinData); + promises = std::move(promises_); + } + notify(std::move(promises)); +} + +std::optional +CudfHashJoinBridge::joinDataOrFuture(ContinueFuture* future) { + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "Calling CudfHashJoinBridge::joinDataOrFuture" << std::endl; + } + std::lock_guard l(mutex_); + if (joinData_.has_value()) { + return joinData_; + } + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) + << "Calling CudfHashJoinBridge::joinDataOrFuture constructing promise" + << std::endl; + } + promises_.emplace_back("CudfHashJoinBridge::joinDataOrFuture"); + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "Calling CudfHashJoinBridge::joinDataOrFuture getSemiFuture" + << std::endl; + } + *future = promises_.back().getSemiFuture(); + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) + << "Calling CudfHashJoinBridge::joinDataOrFuture returning nullopt" + << std::endl; + } + return std::nullopt; +} + CudfHashJoinBuild::CudfHashJoinBuild( int32_t operatorId, exec::DriverCtx* driverCtx, @@ -227,48 +271,155 @@ void CudfHashJoinBuild::noMoreInput() { buildType->getChildIdx(rightKeys[i]->name())); } + // Determine join strategy based on cardinality + auto strategy = determineJoinStrategy(tbls, buildKeyIndices, stream); + + // Build the appropriate join objects + if (strategy == CudfHashJoinBridge::JoinStrategy::kSortMergeJoin) { + buildSortMergeJoin(std::move(tbls), buildKeyIndices, stream); + } else { + buildHashJoin(std::move(tbls), buildKeyIndices, stream); + } +} + +CudfHashJoinBridge::JoinStrategy CudfHashJoinBuild::determineJoinStrategy( + const std::vector>& tbls, + const std::vector& buildKeyIndices, + rmm::cuda_stream_view stream) { + // Sort-merge join only supports inner and left joins + if (!joinNode_->isInnerJoin() && !joinNode_->isLeftJoin()) { + return CudfHashJoinBridge::JoinStrategy::kHashJoin; + } + + // Compute total rows across all batches + size_t totalRows = 0; + for (const auto& tbl : tbls) { + totalRows += static_cast(tbl->num_rows()); + } + + if (totalRows == 0) { + return CudfHashJoinBridge::JoinStrategy::kHashJoin; + } + + // Estimate distinct count across all batches + // For simplicity, use the first batch to estimate cardinality ratio + // A more accurate approach would combine estimates from all batches + size_t distinctCount = + estimateDistinctCount(tbls[0]->view(), buildKeyIndices, stream); + + double cardinalityRatio = + static_cast(distinctCount) / static_cast(totalRows); + + auto threshold = CudfConfig::getInstance().sortMergeJoinCardinalityThreshold; + + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "Join strategy selection: distinctCount=" << distinctCount + << ", totalRows=" << totalRows + << ", cardinalityRatio=" << cardinalityRatio + << ", threshold=" << threshold << std::endl; + } + + if (cardinalityRatio < threshold) { + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "Selecting sort-merge join (cardinalityRatio < threshold)" + << std::endl; + } + return CudfHashJoinBridge::JoinStrategy::kSortMergeJoin; + } + + return CudfHashJoinBridge::JoinStrategy::kHashJoin; +} + +void CudfHashJoinBuild::buildHashJoin( + std::vector>&& tbls, + const std::vector& buildKeyIndices, + rmm::cuda_stream_view stream) { + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "Building hash join objects" << std::endl; + } + // Only need to construct hash_join object if it's an inner join, left join or - // right join. - // All other cases use a standalone function in cudf - bool buildHashJoin = + // right join. All other cases use a standalone function in cudf + bool needHashJoinObject = (joinNode_->isInnerJoin() || joinNode_->isLeftJoin() || joinNode_->isRightJoin()); std::vector> hashObjects; - for (auto i = 0; i < tbls.size(); i++) { + for (size_t i = 0; i < tbls.size(); i++) { hashObjects.push_back( - (buildHashJoin) ? std::make_shared( - tbls[i]->view().select(buildKeyIndices), - cudf::null_equality::UNEQUAL, - stream) - : nullptr); - if (buildHashJoin) { + needHashJoinObject ? std::make_shared( + tbls[i]->view().select(buildKeyIndices), + cudf::null_equality::UNEQUAL, + stream) + : nullptr); + if (needHashJoinObject) { VELOX_CHECK_NOT_NULL(hashObjects.back()); } if (CudfConfig::getInstance().debugEnabled) { if (hashObjects.back() != nullptr) { - LOG(INFO) << "hashObject " << i << " is not nullptr " - << hashObjects.back().get() << "\n"; + LOG(INFO) << "hashObject " << i + << " created: " << hashObjects.back().get() << std::endl; } else { - LOG(INFO) << "hashObject " << i << " is *** nullptr\n"; + LOG(INFO) << "hashObject " << i << " is nullptr" << std::endl; } } } - std::vector> shared_tbls; + std::vector> sharedTbls; for (auto& tbl : tbls) { - shared_tbls.push_back(std::move(tbl)); + sharedTbls.push_back(std::move(tbl)); } - // set hash table to CudfHashJoinBridge + auto joinBridge = operatorCtx_->task()->getCustomJoinBridge( operatorCtx_->driverCtx()->splitGroupId, planNodeId()); auto cudfHashJoinBridge = std::dynamic_pointer_cast(joinBridge); cudfHashJoinBridge->setBuildStream(stream); - cudfHashJoinBridge->setHashTable( - std::make_optional( - std::make_pair(std::move(shared_tbls), std::move(hashObjects)))); + cudfHashJoinBridge->setJoinData(std::make_optional( + CudfHashJoinBridge::join_type{std::in_place_type, + std::move(sharedTbls), + std::move(hashObjects)})); +} + +void CudfHashJoinBuild::buildSortMergeJoin( + std::vector>&& tbls, + const std::vector& buildKeyIndices, + rmm::cuda_stream_view stream) { + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "Building sort-merge join objects" << std::endl; + } + + std::vector> smjObjects; + for (size_t i = 0; i < tbls.size(); i++) { + smjObjects.push_back(std::make_shared( + tbls[i]->view().select(buildKeyIndices), + cudf::sorted::NO, + cudf::null_equality::UNEQUAL, + stream)); + VELOX_CHECK_NOT_NULL(smjObjects.back()); + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "sortMergeJoinObject " << i + << " created: " << smjObjects.back().get() << std::endl; + } + } + + std::vector> sharedTbls; + for (auto& tbl : tbls) { + sharedTbls.push_back(std::move(tbl)); + } + + auto joinBridge = operatorCtx_->task()->getCustomJoinBridge( + operatorCtx_->driverCtx()->splitGroupId, planNodeId()); + auto cudfHashJoinBridge = + std::dynamic_pointer_cast(joinBridge); + + cudfHashJoinBridge->setBuildStream(stream); + cudfHashJoinBridge->setJoinData( + std::make_optional(CudfHashJoinBridge::join_type{ + std::in_place_type, + std::move(sharedTbls), + std::move(smjObjects)})); } exec::BlockingReason CudfHashJoinBuild::isBlocked(ContinueFuture* future) { @@ -1083,13 +1234,114 @@ std::vector> CudfHashJoinProbe::antiJoin( return cudfOutputs; } +std::vector> CudfHashJoinProbe::innerJoinSortMerge( + cudf::table_view leftTableView, + rmm::cuda_stream_view stream) { + std::vector> cudfOutputs; + + auto& rightTables = sortMergeJoinData_.value().first; + auto& smjs = sortMergeJoinData_.value().second; + for (size_t i = 0; i < rightTables.size(); i++) { + auto rightTableView = rightTables[i]->view(); + auto& smj = smjs[i]; + + VELOX_CHECK_NOT_NULL(smj); + if (buildStream_.has_value()) { + cudaEvent_->recordFrom(stream).waitOn(buildStream_.value()); + } + auto [leftJoinIndices, rightJoinIndices] = smj->inner_join( + leftTableView.select(leftKeyIndices_), + cudf::sorted::NO, + buildStream_.has_value() ? buildStream_.value() : stream); + if (buildStream_.has_value()) { + cudaEvent_->recordFrom(buildStream_.value()).waitOn(stream); + } + + auto leftIndicesSpan = + cudf::device_span{*leftJoinIndices}; + auto rightIndicesSpan = + cudf::device_span{*rightJoinIndices}; + auto leftIndicesCol = cudf::column_view{leftIndicesSpan}; + auto rightIndicesCol = cudf::column_view{rightIndicesSpan}; + + if (joinNode_->filter()) { + cudfOutputs.push_back(filteredOutputIndices( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + cudf::join_kind::INNER_JOIN, + stream)); + } else { + cudfOutputs.push_back(unfilteredOutput( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + stream)); + } + } + return cudfOutputs; +} + +std::vector> CudfHashJoinProbe::leftJoinSortMerge( + cudf::table_view leftTableView, + rmm::cuda_stream_view stream) { + std::vector> cudfOutputs; + + auto& rightTables = sortMergeJoinData_.value().first; + auto& smjs = sortMergeJoinData_.value().second; + for (size_t i = 0; i < rightTables.size(); i++) { + auto rightTableView = rightTables[i]->view(); + auto& smj = smjs[i]; + + VELOX_CHECK_NOT_NULL(smj); + if (buildStream_.has_value()) { + cudaEvent_->recordFrom(stream).waitOn(buildStream_.value()); + } + auto [leftJoinIndices, rightJoinIndices] = smj->left_join( + leftTableView.select(leftKeyIndices_), + cudf::sorted::NO, + buildStream_.has_value() ? buildStream_.value() : stream); + if (buildStream_.has_value()) { + cudaEvent_->recordFrom(buildStream_.value()).waitOn(stream); + } + + auto leftIndicesSpan = + cudf::device_span{*leftJoinIndices}; + auto rightIndicesSpan = + cudf::device_span{*rightJoinIndices}; + auto leftIndicesCol = cudf::column_view{leftIndicesSpan}; + auto rightIndicesCol = cudf::column_view{rightIndicesSpan}; + + if (joinNode_->filter()) { + cudfOutputs.push_back(filteredOutputIndices( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + cudf::join_kind::LEFT_JOIN, + stream)); + } else { + cudfOutputs.push_back(unfilteredOutput( + leftTableView, + leftIndicesCol, + rightTableView, + rightIndicesCol, + stream)); + } + } + return cudfOutputs; +} + RowVectorPtr CudfHashJoinProbe::getOutput() { if (CudfConfig::getInstance().debugEnabled) { LOG(INFO) << "Calling CudfHashJoinProbe::getOutput" << std::endl; } VELOX_NVTX_OPERATOR_FUNC_RANGE(); - if (finished_ or !hashObject_.has_value()) { + bool hasJoinData = hashObject_.has_value() || sortMergeJoinData_.has_value(); + if (finished_ || !hasJoinData) { return nullptr; } if (!input_) { @@ -1173,44 +1425,75 @@ RowVectorPtr CudfHashJoinProbe::getOutput() { << std::endl; } - auto& rightTables = hashObject_.value().first; - auto& hbs = hashObject_.value().second; - for (auto i = 0; i < rightTables.size(); i++) { - auto& rightTable = rightTables[i]; - auto& hb = hbs[i]; - VELOX_CHECK_NOT_NULL(rightTable); - if (CudfConfig::getInstance().debugEnabled) { - if (rightTable != nullptr) - LOG(INFO) << "right_table is not nullptr " << rightTable.get() - << " hasValue(" << hashObject_.has_value() << ")\n"; - if (hb != nullptr) - LOG(INFO) << "hb is not nullptr " << hb.get() << " hasValue(" - << hashObject_.has_value() << ")\n"; + // Validate build tables + if (hashObject_.has_value()) { + auto& rightTables = hashObject_.value().first; + auto& hbs = hashObject_.value().second; + for (size_t i = 0; i < rightTables.size(); i++) { + auto& rightTable = rightTables[i]; + auto& hb = hbs[i]; + VELOX_CHECK_NOT_NULL(rightTable); + if (CudfConfig::getInstance().debugEnabled) { + if (rightTable != nullptr) + LOG(INFO) << "right_table is not nullptr " << rightTable.get() + << " hasValue(" << hashObject_.has_value() << ")\n"; + if (hb != nullptr) + LOG(INFO) << "hb is not nullptr " << hb.get() << " hasValue(" + << hashObject_.has_value() << ")\n"; + } + } + } else if (sortMergeJoinData_.has_value()) { + auto& rightTables = sortMergeJoinData_.value().first; + for (size_t i = 0; i < rightTables.size(); i++) { + VELOX_CHECK_NOT_NULL(rightTables[i]); + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "sort_merge right_table " << i << " is not nullptr " + << rightTables[i].get() << std::endl; + } } } std::vector> cudfOutputs; - switch (joinNode_->joinType()) { - case core::JoinType::kInner: - cudfOutputs = innerJoin(leftTableView, stream); - break; - case core::JoinType::kLeft: - cudfOutputs = leftJoin(leftTableView, stream); - break; - case core::JoinType::kRight: - cudfOutputs = rightJoin(leftTableView, stream); - break; - case core::JoinType::kLeftSemiFilter: - cudfOutputs = leftSemiFilterJoin(leftTableView, stream); - break; - case core::JoinType::kRightSemiFilter: - cudfOutputs = rightSemiFilterJoin(leftTableView, stream); - break; - case core::JoinType::kAnti: - cudfOutputs = antiJoin(leftTableView, stream); - break; - default: - VELOX_FAIL("Unsupported join type: ", joinNode_->joinType()); + + // Dispatch based on join strategy and join type + if (sortMergeJoinData_.has_value()) { + // Sort-merge join path (inner and left joins only) + switch (joinNode_->joinType()) { + case core::JoinType::kInner: + cudfOutputs = innerJoinSortMerge(leftTableView, stream); + break; + case core::JoinType::kLeft: + cudfOutputs = leftJoinSortMerge(leftTableView, stream); + break; + default: + VELOX_FAIL( + "Sort-merge join does not support join type: {}", + static_cast(joinNode_->joinType())); + } + } else { + // Hash join path + switch (joinNode_->joinType()) { + case core::JoinType::kInner: + cudfOutputs = innerJoin(leftTableView, stream); + break; + case core::JoinType::kLeft: + cudfOutputs = leftJoin(leftTableView, stream); + break; + case core::JoinType::kRight: + cudfOutputs = rightJoin(leftTableView, stream); + break; + case core::JoinType::kLeftSemiFilter: + cudfOutputs = leftSemiFilterJoin(leftTableView, stream); + break; + case core::JoinType::kRightSemiFilter: + cudfOutputs = rightSemiFilterJoin(leftTableView, stream); + break; + case core::JoinType::kAnti: + cudfOutputs = antiJoin(leftTableView, stream); + break; + default: + VELOX_FAIL("Unsupported join type: ", joinNode_->joinType()); + } } // Release input CudfVector to free GPU memory before creating output. @@ -1242,8 +1525,11 @@ bool CudfHashJoinProbe::skipProbeOnEmptyBuild() const { } exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { + // Check if we already have join data (hash or sort-merge) + bool hasJoinData = hashObject_.has_value() || sortMergeJoinData_.has_value(); + if ((joinNode_->isRightJoin() || joinNode_->isRightSemiFilterJoin()) && - hashObject_.has_value()) { + hasJoinData) { if (!future_.valid()) { return exec::BlockingReason::kNotBlocked; } @@ -1251,7 +1537,7 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { return exec::BlockingReason::kWaitForJoinProbe; } - if (hashObject_.has_value()) { + if (hasJoinData) { return exec::BlockingReason::kNotBlocked; } @@ -1261,20 +1547,33 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { std::dynamic_pointer_cast(joinBridge); VELOX_CHECK_NOT_NULL(cudfJoinBridge); VELOX_CHECK_NOT_NULL(future); - auto hashObject = cudfJoinBridge->hashOrFuture(future); + auto joinData = cudfJoinBridge->joinDataOrFuture(future); - if (!hashObject.has_value()) { + if (!joinData.has_value()) { if (CudfConfig::getInstance().debugEnabled) { LOG(INFO) << "CudfHashJoinProbe is blocked, waiting for join build" << std::endl; } return exec::BlockingReason::kWaitForJoinBuild; } - hashObject_ = std::move(hashObject); + + // Store join data based on variant type + if (std::holds_alternative(*joinData)) { + hashObject_ = std::get(*joinData); + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "CudfHashJoinProbe received hash join data" << std::endl; + } + } else { + sortMergeJoinData_ = std::get(*joinData); + if (CudfConfig::getInstance().debugEnabled) { + LOG(INFO) << "CudfHashJoinProbe received sort-merge join data" + << std::endl; + } + } buildStream_ = cudfJoinBridge->getBuildStream(); - // Lazy initialize matched flags only when build side is done - if (joinNode_->isRightJoin()) { + // Lazy initialize matched flags only when build side is done (hash join only) + if (joinNode_->isRightJoin() && hashObject_.has_value()) { auto& rightTablesInit = hashObject_.value().first; rightMatchedFlags_.clear(); rightMatchedFlags_.reserve(rightTablesInit.size()); @@ -1288,10 +1587,18 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { } initStream.synchronize(); } - auto& rightTables = hashObject_.value().first; + + // Get right tables for empty build check + const std::vector>* rightTables = nullptr; + if (hashObject_.has_value()) { + rightTables = &hashObject_.value().first; + } else if (sortMergeJoinData_.has_value()) { + rightTables = &sortMergeJoinData_.value().first; + } + // should be rightTable->numDistinct() but it needs compute, // so we use num_rows() - if (rightTables[0]->num_rows() == 0) { + if (rightTables && (*rightTables)[0]->num_rows() == 0) { if (skipProbeOnEmptyBuild()) { if (operatorCtx_->driverCtx() ->queryConfig() @@ -1313,9 +1620,10 @@ exec::BlockingReason CudfHashJoinProbe::isBlocked(ContinueFuture* future) { bool CudfHashJoinProbe::isFinished() { auto const isFinished = finished_ || (noMoreInput_ && input_ == nullptr); - // Release hashObject_ if finished + // Release join data if finished if (isFinished) { hashObject_.reset(); + sortMergeJoinData_.reset(); } return isFinished; } diff --git a/velox/experimental/cudf/exec/CudfHashJoin.h b/velox/experimental/cudf/exec/CudfHashJoin.h index 2606f9c1642..3bb165d1055 100644 --- a/velox/experimental/cudf/exec/CudfHashJoin.h +++ b/velox/experimental/cudf/exec/CudfHashJoin.h @@ -27,11 +27,13 @@ #include #include #include +#include #include #include #include +#include namespace facebook::velox::cudf_velox { @@ -52,6 +54,9 @@ class CudaEvent; */ class CudfHashJoinBridge : public exec::JoinBridge { public: + /// Join strategy discriminator + enum class JoinStrategy { kHashJoin, kSortMergeJoin }; + // The bridge transfers all build side batches and the hash join objects // constructed from them to the probe operator /** @brief Hash tables paired with their corresponding join objects for @@ -60,10 +65,24 @@ class CudfHashJoinBridge : public exec::JoinBridge { std::vector>, std::vector>>; + /** @brief Sort-merge join type with same structure for batching support */ + using sort_merge_join_type = std::pair< + std::vector>, + std::vector>>; + + /** @brief Union type supporting both hash join and sort-merge join */ + using join_type = std::variant; + void setHashTable(std::optional hashObject); std::optional hashOrFuture(ContinueFuture* future); + /// Set join data (supports both hash join and sort-merge join) + void setJoinData(std::optional joinData); + + /// Get join data or register for future notification + std::optional joinDataOrFuture(ContinueFuture* future); + // Store and retrieve the CUDA stream used for building the hash join. void setBuildStream(rmm::cuda_stream_view buildStream); @@ -73,6 +92,8 @@ class CudfHashJoinBridge : public exec::JoinBridge { /** @brief Hash tables and join objects transferred from build to probe * operators */ std::optional hashObject_; + /** @brief Join data (hash or sort-merge) transferred from build to probe */ + std::optional joinData_; /** @brief CUDA stream used by build operator for proper synchronization */ std::optional buildStream_; }; @@ -107,6 +128,25 @@ class CudfHashJoinBuild : public exec::Operator, public NvtxHelper { bool isFinished() override; private: + /// Determines whether to use hash join or sort-merge join based on + /// cardinality ratio of build table keys + CudfHashJoinBridge::JoinStrategy determineJoinStrategy( + const std::vector>& tbls, + const std::vector& buildKeyIndices, + rmm::cuda_stream_view stream); + + /// Build hash join objects and set them via the bridge + void buildHashJoin( + std::vector>&& tbls, + const std::vector& buildKeyIndices, + rmm::cuda_stream_view stream); + + /// Build sort-merge join objects and set them via the bridge + void buildSortMergeJoin( + std::vector>&& tbls, + const std::vector& buildKeyIndices, + rmm::cuda_stream_view stream); + std::shared_ptr joinNode_; std::vector inputs_; ContinueFuture future_{ContinueFuture::makeEmpty()}; @@ -126,6 +166,7 @@ class CudfHashJoinBuild : public exec::Operator, public NvtxHelper { class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { public: using hash_type = CudfHashJoinBridge::hash_type; + using sort_merge_join_type = CudfHashJoinBridge::sort_merge_join_type; CudfHashJoinProbe( int32_t operatorId, @@ -159,6 +200,8 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { std::shared_ptr joinNode_; /** @brief Hash tables and join objects received from build operator */ std::optional hashObject_; + /** @brief Sort-merge join objects received from build operator */ + std::optional sortMergeJoinData_; // Filter related members /** @brief CUDF AST tree for join filter evaluation */ @@ -310,6 +353,28 @@ class CudfHashJoinProbe : public exec::Operator, public NvtxHelper { cudf::column_view rightIndicesCol, cudf::join_kind joinKind, rmm::cuda_stream_view stream); + + /** + * @brief Performs sort-merge inner join between probe table and all build + * tables. + * @param leftTableView Probe-side table view to join + * @param stream CUDA stream for operations + * @return Vector of result tables (multiple if build data was batched) + */ + std::vector> innerJoinSortMerge( + cudf::table_view leftTableView, + rmm::cuda_stream_view stream); + + /** + * @brief Performs sort-merge left join between probe table and all build + * tables. + * @param leftTableView Probe-side table view to join + * @param stream CUDA stream for operations + * @return Vector of result tables (multiple if build data was batched) + */ + std::vector> leftJoinSortMerge( + cudf::table_view leftTableView, + rmm::cuda_stream_view stream); }; /** diff --git a/velox/experimental/cudf/exec/ToCudf.cpp b/velox/experimental/cudf/exec/ToCudf.cpp index 2ef433f6d1b..774619adb68 100644 --- a/velox/experimental/cudf/exec/ToCudf.cpp +++ b/velox/experimental/cudf/exec/ToCudf.cpp @@ -556,6 +556,10 @@ void CudfConfig::initialize( if (config.find(kCudfLogFallback) != config.end()) { logFallback = folly::to(config[kCudfLogFallback]); } + if (config.find(kCudfSortMergeJoinCardinalityThreshold) != config.end()) { + sortMergeJoinCardinalityThreshold = + folly::to(config[kCudfSortMergeJoinCardinalityThreshold]); + } } } // namespace facebook::velox::cudf_velox diff --git a/velox/experimental/cudf/exec/Utilities.cpp b/velox/experimental/cudf/exec/Utilities.cpp index 90aa838671a..9c7074f1188 100644 --- a/velox/experimental/cudf/exec/Utilities.cpp +++ b/velox/experimental/cudf/exec/Utilities.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -261,6 +262,23 @@ std::vector> getConcatenatedTableBatched( return outputTables; } +std::size_t estimateDistinctCount( + cudf::table_view table, + const std::vector& keyIndices, + rmm::cuda_stream_view stream) { + auto keyTable = table.select(keyIndices); + // Precision 12 gives ~1.6% standard error, good balance of accuracy and + // memory + constexpr int32_t kPrecision = 12; + auto estimator = cudf::approx_distinct_count( + keyTable, + kPrecision, + cudf::null_policy::EXCLUDE, + cudf::nan_policy::NAN_IS_NULL, + stream); + return estimator.estimate(stream); +} + CudaEvent::CudaEvent(unsigned int flags) { cudaEvent_t ev{}; cudaEventCreateWithFlags(&ev, flags); diff --git a/velox/experimental/cudf/exec/Utilities.h b/velox/experimental/cudf/exec/Utilities.h index 24f201e586c..8cebcdb6c9e 100644 --- a/velox/experimental/cudf/exec/Utilities.h +++ b/velox/experimental/cudf/exec/Utilities.h @@ -111,6 +111,22 @@ getConcatenatedTableBatched( * // ... launch kernels on stream2 (will wait for stream1 to reach event) ... * @endcode */ +/** + * @brief Estimates the number of distinct values in the specified key columns. + * + * Uses HyperLogLog algorithm via cudf::approx_distinct_count with precision 12, + * providing approximately 1.6% standard error in the estimate. + * + * @param table Table to analyze + * @param keyIndices Column indices for the key columns + * @param stream CUDA stream for operations + * @return Approximate count of distinct key combinations + */ +[[nodiscard]] std::size_t estimateDistinctCount( + cudf::table_view table, + const std::vector& keyIndices, + rmm::cuda_stream_view stream); + class CudaEvent { public: /** diff --git a/velox/experimental/cudf/tests/HashJoinTest.cpp b/velox/experimental/cudf/tests/HashJoinTest.cpp index d874b594a7b..972f2a38027 100644 --- a/velox/experimental/cudf/tests/HashJoinTest.cpp +++ b/velox/experimental/cudf/tests/HashJoinTest.cpp @@ -8494,4 +8494,164 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashTableCleanupAfterProbeFinish) { ASSERT_TRUE(tableEmpty); } +// Tests for Sort-Merge Join strategy selection based on cardinality + +TEST_F(HashJoinTest, sortMergeJoinInner) { + // Create build table with very low cardinality (few distinct keys) + // This should trigger sort-merge join strategy + auto buildVectors = std::vector{makeRowVector( + {"u_k1", "u_v1"}, + {makeFlatVector(1000, [](auto row) { return row % 5; }), + makeFlatVector(1000, [](auto row) { return row; })})}; + + auto probeVectors = std::vector{makeRowVector( + {"t_k1", "t_v1"}, + {makeFlatVector(500, [](auto row) { return row % 5; }), + makeFlatVector(500, [](auto row) { return row; })})}; + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + {"t_k1", "t_v1", "u_k1", "u_v1"}, + core::JoinType::kInner) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .numDrivers(1) + .injectSpill(false) + .referenceQuery( + "SELECT t_k1, t_v1, u_k1, u_v1 FROM t INNER JOIN u ON t_k1 = u_k1") + .run(); +} + +TEST_F(HashJoinTest, sortMergeJoinLeft) { + // Create build table with very low cardinality for left join + auto buildVectors = std::vector{makeRowVector( + {"u_k1", "u_v1"}, + {makeFlatVector(1000, [](auto row) { return row % 3; }), + makeFlatVector(1000, [](auto row) { return row; })})}; + + auto probeVectors = std::vector{makeRowVector( + {"t_k1", "t_v1"}, + {makeFlatVector(500, [](auto row) { return row % 10; }), + makeFlatVector(500, [](auto row) { return row; })})}; + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + {"t_k1", "t_v1", "u_k1", "u_v1"}, + core::JoinType::kLeft) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .numDrivers(1) + .injectSpill(false) + .referenceQuery( + "SELECT t_k1, t_v1, u_k1, u_v1 FROM t LEFT JOIN u ON t_k1 = u_k1") + .run(); +} + +TEST_F(HashJoinTest, sortMergeJoinHighCardinalityUsesHashJoin) { + // Create build table with high cardinality (unique keys) + // This should use hash join strategy instead of sort-merge join + auto buildVectors = std::vector{makeRowVector( + {"u_k1", "u_v1"}, + {makeFlatVector(1000, [](auto row) { return row; }), + makeFlatVector(1000, [](auto row) { return row; })})}; + + auto probeVectors = std::vector{makeRowVector( + {"t_k1", "t_v1"}, + {makeFlatVector(500, [](auto row) { return row; }), + makeFlatVector(500, [](auto row) { return row; })})}; + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + {"t_k1", "t_v1", "u_k1", "u_v1"}, + core::JoinType::kInner) + .planNode(); + + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .numDrivers(1) + .injectSpill(false) + .referenceQuery( + "SELECT t_k1, t_v1, u_k1, u_v1 FROM t INNER JOIN u ON t_k1 = u_k1") + .run(); +} + +TEST_F(HashJoinTest, sortMergeJoinThresholdConfig) { + // Test that the cardinality threshold configuration affects strategy + // selection. Set threshold to 0 to force hash join even with low cardinality. + auto buildVectors = std::vector{makeRowVector( + {"u_k1", "u_v1"}, + {makeFlatVector(1000, [](auto row) { return row % 5; }), + makeFlatVector(1000, [](auto row) { return row; })})}; + + auto probeVectors = std::vector{makeRowVector( + {"t_k1", "t_v1"}, + {makeFlatVector(500, [](auto row) { return row % 5; }), + makeFlatVector(500, [](auto row) { return row; })})}; + + createDuckDbTable("t", probeVectors); + createDuckDbTable("u", buildVectors); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(probeVectors, true) + .hashJoin( + {"t_k1"}, + {"u_k1"}, + PlanBuilder(planNodeIdGenerator) + .values(buildVectors, true) + .planNode(), + "", + {"t_k1", "t_v1", "u_k1", "u_v1"}, + core::JoinType::kInner) + .planNode(); + + // Set threshold to 0 to force hash join + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .numDrivers(1) + .injectSpill(false) + .config( + cudf_velox::CudfConfig::kCudfSortMergeJoinCardinalityThreshold, "0.0") + .referenceQuery( + "SELECT t_k1, t_v1, u_k1, u_v1 FROM t INNER JOIN u ON t_k1 = u_k1") + .run(); +} + } // namespace