Skip to content

Commit 20870c1

Browse files
authored
feat: integrate batch coalescer with repartition exec (#19002)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Part of #18782. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> `RepartitionExec` has two cases: sort-preserving and non sort-preserving. This change integrates `LimitedBatchCoalescer` with the latter. For the former, it seems that `SortPreservingMergeStream` that builds on top of `PerPartitionStreams` has batching logic built in: https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L279-L289 hence I did not include in this change. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Yes ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> No
1 parent a30cf37 commit 20870c1

File tree

2 files changed

+120
-28
lines changed

2 files changed

+120
-28
lines changed

datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use datafusion::prelude::*;
3535
use datafusion::scalar::ScalarValue;
3636
use datafusion_catalog::Session;
3737
use datafusion_common::cast::as_primitive_array;
38-
use datafusion_common::{internal_err, not_impl_err};
38+
use datafusion_common::{internal_err, not_impl_err, DataFusionError};
3939
use datafusion_expr::expr::{BinaryExpr, Cast};
4040
use datafusion_functions_aggregate::expr_fn::count;
4141
use datafusion_physical_expr::EquivalenceProperties;
@@ -134,9 +134,19 @@ impl ExecutionPlan for CustomPlan {
134134
_partition: usize,
135135
_context: Arc<TaskContext>,
136136
) -> Result<SendableRecordBatchStream> {
137+
let schema_captured = self.schema().clone();
137138
Ok(Box::pin(RecordBatchStreamAdapter::new(
138139
self.schema(),
139-
futures::stream::iter(self.batches.clone().into_iter().map(Ok)),
140+
futures::stream::iter(self.batches.clone().into_iter().map(move |batch| {
141+
let projection: Vec<usize> = schema_captured
142+
.fields()
143+
.iter()
144+
.filter_map(|field| batch.schema().index_of(field.name()).ok())
145+
.collect();
146+
batch
147+
.project(&projection)
148+
.map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
149+
})),
140150
)))
141151
}
142152

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 108 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
3030
use super::{
3131
DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
3232
};
33+
use crate::coalesce::LimitedBatchCoalescer;
3334
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
3435
use crate::hash_utils::create_hashes;
3536
use crate::metrics::{BaselineMetrics, SpillMetrics};
@@ -62,7 +63,7 @@ use crate::filter_pushdown::{
6263
};
6364
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
6465
use futures::stream::Stream;
65-
use futures::{FutureExt, StreamExt, TryStreamExt};
66+
use futures::{ready, FutureExt, StreamExt, TryStreamExt};
6667
use log::trace;
6768
use parking_lot::Mutex;
6869

@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
932933
spill_stream,
933934
1, // Each receiver handles one input partition
934935
BaselineMetrics::new(&metrics, partition),
936+
None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286
935937
)) as SendableRecordBatchStream
936938
})
937939
.collect::<Vec<_>>();
@@ -970,6 +972,7 @@ impl ExecutionPlan for RepartitionExec {
970972
spill_stream,
971973
num_input_partitions,
972974
BaselineMetrics::new(&metrics, partition),
975+
Some(context.session_config().batch_size()),
973976
)) as SendableRecordBatchStream)
974977
}
975978
})
@@ -1427,9 +1430,13 @@ struct PerPartitionStream {
14271430

14281431
/// Execution metrics
14291432
baseline_metrics: BaselineMetrics,
1433+
1434+
/// None for sort preserving variant (merge sort already does coalescing)
1435+
batch_coalescer: Option<LimitedBatchCoalescer>,
14301436
}
14311437

14321438
impl PerPartitionStream {
1439+
#[expect(clippy::too_many_arguments)]
14331440
fn new(
14341441
schema: SchemaRef,
14351442
receiver: DistributionReceiver<MaybeBatch>,
@@ -1438,7 +1445,10 @@ impl PerPartitionStream {
14381445
spill_stream: SendableRecordBatchStream,
14391446
num_input_partitions: usize,
14401447
baseline_metrics: BaselineMetrics,
1448+
batch_size: Option<usize>,
14411449
) -> Self {
1450+
let batch_coalescer =
1451+
batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None));
14421452
Self {
14431453
schema,
14441454
receiver,
@@ -1448,6 +1458,7 @@ impl PerPartitionStream {
14481458
state: StreamState::ReadingMemory,
14491459
remaining_partitions: num_input_partitions,
14501460
baseline_metrics,
1461+
batch_coalescer,
14511462
}
14521463
}
14531464

@@ -1531,6 +1542,43 @@ impl PerPartitionStream {
15311542
}
15321543
}
15331544
}
1545+
1546+
fn poll_next_and_coalesce(
1547+
self: &mut Pin<&mut Self>,
1548+
cx: &mut Context<'_>,
1549+
coalescer: &mut LimitedBatchCoalescer,
1550+
) -> Poll<Option<Result<RecordBatch>>> {
1551+
let cloned_time = self.baseline_metrics.elapsed_compute().clone();
1552+
let mut completed = false;
1553+
1554+
loop {
1555+
if let Some(batch) = coalescer.next_completed_batch() {
1556+
return Poll::Ready(Some(Ok(batch)));
1557+
}
1558+
if completed {
1559+
return Poll::Ready(None);
1560+
}
1561+
1562+
match ready!(self.poll_next_inner(cx)) {
1563+
Some(Ok(batch)) => {
1564+
let _timer = cloned_time.timer();
1565+
if let Err(err) = coalescer.push_batch(batch) {
1566+
return Poll::Ready(Some(Err(err)));
1567+
}
1568+
}
1569+
Some(err) => {
1570+
return Poll::Ready(Some(err));
1571+
}
1572+
None => {
1573+
completed = true;
1574+
let _timer = cloned_time.timer();
1575+
if let Err(err) = coalescer.finish() {
1576+
return Poll::Ready(Some(Err(err)));
1577+
}
1578+
}
1579+
}
1580+
}
1581+
}
15341582
}
15351583

15361584
impl Stream for PerPartitionStream {
@@ -1540,7 +1588,13 @@ impl Stream for PerPartitionStream {
15401588
mut self: Pin<&mut Self>,
15411589
cx: &mut Context<'_>,
15421590
) -> Poll<Option<Self::Item>> {
1543-
let poll = self.poll_next_inner(cx);
1591+
let poll;
1592+
if let Some(mut coalescer) = self.batch_coalescer.take() {
1593+
poll = self.poll_next_and_coalesce(cx, &mut coalescer);
1594+
self.batch_coalescer = Some(coalescer);
1595+
} else {
1596+
poll = self.poll_next_inner(cx);
1597+
}
15441598
self.baseline_metrics.record_poll(poll)
15451599
}
15461600
}
@@ -1575,9 +1629,9 @@ mod tests {
15751629
use datafusion_common::exec_err;
15761630
use datafusion_common::test_util::batches_to_sort_string;
15771631
use datafusion_common_runtime::JoinSet;
1632+
use datafusion_execution::config::SessionConfig;
15781633
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
15791634
use insta::assert_snapshot;
1580-
use itertools::Itertools;
15811635

15821636
#[tokio::test]
15831637
async fn one_to_many_round_robin() -> Result<()> {
@@ -1591,10 +1645,13 @@ mod tests {
15911645
repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;
15921646

15931647
assert_eq!(4, output_partitions.len());
1594-
assert_eq!(13, output_partitions[0].len());
1595-
assert_eq!(13, output_partitions[1].len());
1596-
assert_eq!(12, output_partitions[2].len());
1597-
assert_eq!(12, output_partitions[3].len());
1648+
for partition in &output_partitions {
1649+
assert_eq!(1, partition.len());
1650+
}
1651+
assert_eq!(13 * 8, output_partitions[0][0].num_rows());
1652+
assert_eq!(13 * 8, output_partitions[1][0].num_rows());
1653+
assert_eq!(12 * 8, output_partitions[2][0].num_rows());
1654+
assert_eq!(12 * 8, output_partitions[3][0].num_rows());
15981655

15991656
Ok(())
16001657
}
@@ -1611,7 +1668,7 @@ mod tests {
16111668
repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;
16121669

16131670
assert_eq!(1, output_partitions.len());
1614-
assert_eq!(150, output_partitions[0].len());
1671+
assert_eq!(150 * 8, output_partitions[0][0].num_rows());
16151672

16161673
Ok(())
16171674
}
@@ -1627,12 +1684,12 @@ mod tests {
16271684
let output_partitions =
16281685
repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;
16291686

1687+
let total_rows_per_partition = 8 * 50 * 3 / 5;
16301688
assert_eq!(5, output_partitions.len());
1631-
assert_eq!(30, output_partitions[0].len());
1632-
assert_eq!(30, output_partitions[1].len());
1633-
assert_eq!(30, output_partitions[2].len());
1634-
assert_eq!(30, output_partitions[3].len());
1635-
assert_eq!(30, output_partitions[4].len());
1689+
for partition in output_partitions {
1690+
assert_eq!(1, partition.len());
1691+
assert_eq!(total_rows_per_partition, partition[0].num_rows());
1692+
}
16361693

16371694
Ok(())
16381695
}
@@ -1662,6 +1719,32 @@ mod tests {
16621719
Ok(())
16631720
}
16641721

1722+
#[tokio::test]
1723+
async fn test_repartition_with_coalescing() -> Result<()> {
1724+
let schema = test_schema();
1725+
// create 50 batches, each having 8 rows
1726+
let partition = create_vec_batches(50);
1727+
let partitions = vec![partition.clone(), partition.clone()];
1728+
let partitioning = Partitioning::RoundRobinBatch(1);
1729+
1730+
let session_config = SessionConfig::new().with_batch_size(200);
1731+
let task_ctx = TaskContext::default().with_session_config(session_config);
1732+
let task_ctx = Arc::new(task_ctx);
1733+
1734+
// create physical plan
1735+
let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
1736+
let exec = RepartitionExec::try_new(exec, partitioning)?;
1737+
1738+
for i in 0..exec.partitioning().partition_count() {
1739+
let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
1740+
while let Some(result) = stream.next().await {
1741+
let batch = result?;
1742+
assert_eq!(200, batch.num_rows());
1743+
}
1744+
}
1745+
Ok(())
1746+
}
1747+
16651748
fn test_schema() -> Arc<Schema> {
16661749
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
16671750
}
@@ -1707,12 +1790,12 @@ mod tests {
17071790

17081791
let output_partitions = handle.join().await.unwrap().unwrap();
17091792

1793+
let total_rows_per_partition = 8 * 50 * 3 / 5;
17101794
assert_eq!(5, output_partitions.len());
1711-
assert_eq!(30, output_partitions[0].len());
1712-
assert_eq!(30, output_partitions[1].len());
1713-
assert_eq!(30, output_partitions[2].len());
1714-
assert_eq!(30, output_partitions[3].len());
1715-
assert_eq!(30, output_partitions[4].len());
1795+
for partition in output_partitions {
1796+
assert_eq!(1, partition.len());
1797+
assert_eq!(total_rows_per_partition, partition[0].num_rows());
1798+
}
17161799

17171800
Ok(())
17181801
}
@@ -1950,14 +2033,13 @@ mod tests {
19502033
});
19512034
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
19522035

1953-
fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
1954-
batch
1955-
.into_iter()
1956-
.sorted_by_key(|b| format!("{b:?}"))
1957-
.collect()
1958-
}
1959-
1960-
assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
2036+
let items_vec_with_drop = str_batches_to_vec(&batches_with_drop);
2037+
let items_set_with_drop: HashSet<&str> =
2038+
items_vec_with_drop.iter().copied().collect();
2039+
assert_eq!(
2040+
items_set_with_drop.symmetric_difference(&items_set).count(),
2041+
0
2042+
);
19612043
}
19622044

19632045
fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {

0 commit comments

Comments
 (0)