@@ -30,6 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
3030use super :: {
3131 DisplayAs , ExecutionPlanProperties , RecordBatchStream , SendableRecordBatchStream ,
3232} ;
33+ use crate :: coalesce:: LimitedBatchCoalescer ;
3334use crate :: execution_plan:: { CardinalityEffect , EvaluationType , SchedulingType } ;
3435use crate :: hash_utils:: create_hashes;
3536use crate :: metrics:: { BaselineMetrics , SpillMetrics } ;
@@ -62,7 +63,7 @@ use crate::filter_pushdown::{
6263} ;
6364use datafusion_physical_expr_common:: utils:: evaluate_expressions_to_arrays;
6465use futures:: stream:: Stream ;
65- use futures:: { FutureExt , StreamExt , TryStreamExt } ;
66+ use futures:: { ready , FutureExt , StreamExt , TryStreamExt } ;
6667use log:: trace;
6768use parking_lot:: Mutex ;
6869
@@ -932,6 +933,7 @@ impl ExecutionPlan for RepartitionExec {
932933 spill_stream,
933934 1 , // Each receiver handles one input partition
934935 BaselineMetrics :: new ( & metrics, partition) ,
936+ None , // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286
935937 ) ) as SendableRecordBatchStream
936938 } )
937939 . collect :: < Vec < _ > > ( ) ;
@@ -970,6 +972,7 @@ impl ExecutionPlan for RepartitionExec {
970972 spill_stream,
971973 num_input_partitions,
972974 BaselineMetrics :: new ( & metrics, partition) ,
975+ Some ( context. session_config ( ) . batch_size ( ) ) ,
973976 ) ) as SendableRecordBatchStream )
974977 }
975978 } )
@@ -1427,9 +1430,13 @@ struct PerPartitionStream {
14271430
14281431 /// Execution metrics
14291432 baseline_metrics : BaselineMetrics ,
1433+
1434+ /// None for sort preserving variant (merge sort already does coalescing)
1435+ batch_coalescer : Option < LimitedBatchCoalescer > ,
14301436}
14311437
14321438impl PerPartitionStream {
1439+ #[ expect( clippy:: too_many_arguments) ]
14331440 fn new (
14341441 schema : SchemaRef ,
14351442 receiver : DistributionReceiver < MaybeBatch > ,
@@ -1438,7 +1445,10 @@ impl PerPartitionStream {
14381445 spill_stream : SendableRecordBatchStream ,
14391446 num_input_partitions : usize ,
14401447 baseline_metrics : BaselineMetrics ,
1448+ batch_size : Option < usize > ,
14411449 ) -> Self {
1450+ let batch_coalescer =
1451+ batch_size. map ( |s| LimitedBatchCoalescer :: new ( Arc :: clone ( & schema) , s, None ) ) ;
14421452 Self {
14431453 schema,
14441454 receiver,
@@ -1448,6 +1458,7 @@ impl PerPartitionStream {
14481458 state : StreamState :: ReadingMemory ,
14491459 remaining_partitions : num_input_partitions,
14501460 baseline_metrics,
1461+ batch_coalescer,
14511462 }
14521463 }
14531464
@@ -1531,6 +1542,43 @@ impl PerPartitionStream {
15311542 }
15321543 }
15331544 }
1545+
1546+ fn poll_next_and_coalesce (
1547+ self : & mut Pin < & mut Self > ,
1548+ cx : & mut Context < ' _ > ,
1549+ coalescer : & mut LimitedBatchCoalescer ,
1550+ ) -> Poll < Option < Result < RecordBatch > > > {
1551+ let cloned_time = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
1552+ let mut completed = false ;
1553+
1554+ loop {
1555+ if let Some ( batch) = coalescer. next_completed_batch ( ) {
1556+ return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
1557+ }
1558+ if completed {
1559+ return Poll :: Ready ( None ) ;
1560+ }
1561+
1562+ match ready ! ( self . poll_next_inner( cx) ) {
1563+ Some ( Ok ( batch) ) => {
1564+ let _timer = cloned_time. timer ( ) ;
1565+ if let Err ( err) = coalescer. push_batch ( batch) {
1566+ return Poll :: Ready ( Some ( Err ( err) ) ) ;
1567+ }
1568+ }
1569+ Some ( err) => {
1570+ return Poll :: Ready ( Some ( err) ) ;
1571+ }
1572+ None => {
1573+ completed = true ;
1574+ let _timer = cloned_time. timer ( ) ;
1575+ if let Err ( err) = coalescer. finish ( ) {
1576+ return Poll :: Ready ( Some ( Err ( err) ) ) ;
1577+ }
1578+ }
1579+ }
1580+ }
1581+ }
15341582}
15351583
15361584impl Stream for PerPartitionStream {
@@ -1540,7 +1588,13 @@ impl Stream for PerPartitionStream {
15401588 mut self : Pin < & mut Self > ,
15411589 cx : & mut Context < ' _ > ,
15421590 ) -> Poll < Option < Self :: Item > > {
1543- let poll = self . poll_next_inner ( cx) ;
1591+ let poll;
1592+ if let Some ( mut coalescer) = self . batch_coalescer . take ( ) {
1593+ poll = self . poll_next_and_coalesce ( cx, & mut coalescer) ;
1594+ self . batch_coalescer = Some ( coalescer) ;
1595+ } else {
1596+ poll = self . poll_next_inner ( cx) ;
1597+ }
15441598 self . baseline_metrics . record_poll ( poll)
15451599 }
15461600}
@@ -1575,9 +1629,9 @@ mod tests {
15751629 use datafusion_common:: exec_err;
15761630 use datafusion_common:: test_util:: batches_to_sort_string;
15771631 use datafusion_common_runtime:: JoinSet ;
1632+ use datafusion_execution:: config:: SessionConfig ;
15781633 use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
15791634 use insta:: assert_snapshot;
1580- use itertools:: Itertools ;
15811635
15821636 #[ tokio:: test]
15831637 async fn one_to_many_round_robin ( ) -> Result < ( ) > {
@@ -1591,10 +1645,13 @@ mod tests {
15911645 repartition ( & schema, partitions, Partitioning :: RoundRobinBatch ( 4 ) ) . await ?;
15921646
15931647 assert_eq ! ( 4 , output_partitions. len( ) ) ;
1594- assert_eq ! ( 13 , output_partitions[ 0 ] . len( ) ) ;
1595- assert_eq ! ( 13 , output_partitions[ 1 ] . len( ) ) ;
1596- assert_eq ! ( 12 , output_partitions[ 2 ] . len( ) ) ;
1597- assert_eq ! ( 12 , output_partitions[ 3 ] . len( ) ) ;
1648+ for partition in & output_partitions {
1649+ assert_eq ! ( 1 , partition. len( ) ) ;
1650+ }
1651+ assert_eq ! ( 13 * 8 , output_partitions[ 0 ] [ 0 ] . num_rows( ) ) ;
1652+ assert_eq ! ( 13 * 8 , output_partitions[ 1 ] [ 0 ] . num_rows( ) ) ;
1653+ assert_eq ! ( 12 * 8 , output_partitions[ 2 ] [ 0 ] . num_rows( ) ) ;
1654+ assert_eq ! ( 12 * 8 , output_partitions[ 3 ] [ 0 ] . num_rows( ) ) ;
15981655
15991656 Ok ( ( ) )
16001657 }
@@ -1611,7 +1668,7 @@ mod tests {
16111668 repartition ( & schema, partitions, Partitioning :: RoundRobinBatch ( 1 ) ) . await ?;
16121669
16131670 assert_eq ! ( 1 , output_partitions. len( ) ) ;
1614- assert_eq ! ( 150 , output_partitions[ 0 ] . len ( ) ) ;
1671+ assert_eq ! ( 150 * 8 , output_partitions[ 0 ] [ 0 ] . num_rows ( ) ) ;
16151672
16161673 Ok ( ( ) )
16171674 }
@@ -1627,12 +1684,12 @@ mod tests {
16271684 let output_partitions =
16281685 repartition ( & schema, partitions, Partitioning :: RoundRobinBatch ( 5 ) ) . await ?;
16291686
1687+ let total_rows_per_partition = 8 * 50 * 3 / 5 ;
16301688 assert_eq ! ( 5 , output_partitions. len( ) ) ;
1631- assert_eq ! ( 30 , output_partitions[ 0 ] . len( ) ) ;
1632- assert_eq ! ( 30 , output_partitions[ 1 ] . len( ) ) ;
1633- assert_eq ! ( 30 , output_partitions[ 2 ] . len( ) ) ;
1634- assert_eq ! ( 30 , output_partitions[ 3 ] . len( ) ) ;
1635- assert_eq ! ( 30 , output_partitions[ 4 ] . len( ) ) ;
1689+ for partition in output_partitions {
1690+ assert_eq ! ( 1 , partition. len( ) ) ;
1691+ assert_eq ! ( total_rows_per_partition, partition[ 0 ] . num_rows( ) ) ;
1692+ }
16361693
16371694 Ok ( ( ) )
16381695 }
@@ -1662,6 +1719,32 @@ mod tests {
16621719 Ok ( ( ) )
16631720 }
16641721
1722+ #[ tokio:: test]
1723+ async fn test_repartition_with_coalescing ( ) -> Result < ( ) > {
1724+ let schema = test_schema ( ) ;
1725+ // create 50 batches, each having 8 rows
1726+ let partition = create_vec_batches ( 50 ) ;
1727+ let partitions = vec ! [ partition. clone( ) , partition. clone( ) ] ;
1728+ let partitioning = Partitioning :: RoundRobinBatch ( 1 ) ;
1729+
1730+ let session_config = SessionConfig :: new ( ) . with_batch_size ( 200 ) ;
1731+ let task_ctx = TaskContext :: default ( ) . with_session_config ( session_config) ;
1732+ let task_ctx = Arc :: new ( task_ctx) ;
1733+
1734+ // create physical plan
1735+ let exec = TestMemoryExec :: try_new_exec ( & partitions, Arc :: clone ( & schema) , None ) ?;
1736+ let exec = RepartitionExec :: try_new ( exec, partitioning) ?;
1737+
1738+ for i in 0 ..exec. partitioning ( ) . partition_count ( ) {
1739+ let mut stream = exec. execute ( i, Arc :: clone ( & task_ctx) ) ?;
1740+ while let Some ( result) = stream. next ( ) . await {
1741+ let batch = result?;
1742+ assert_eq ! ( 200 , batch. num_rows( ) ) ;
1743+ }
1744+ }
1745+ Ok ( ( ) )
1746+ }
1747+
16651748 fn test_schema ( ) -> Arc < Schema > {
16661749 Arc :: new ( Schema :: new ( vec ! [ Field :: new( "c0" , DataType :: UInt32 , false ) ] ) )
16671750 }
@@ -1707,12 +1790,12 @@ mod tests {
17071790
17081791 let output_partitions = handle. join ( ) . await . unwrap ( ) . unwrap ( ) ;
17091792
1793+ let total_rows_per_partition = 8 * 50 * 3 / 5 ;
17101794 assert_eq ! ( 5 , output_partitions. len( ) ) ;
1711- assert_eq ! ( 30 , output_partitions[ 0 ] . len( ) ) ;
1712- assert_eq ! ( 30 , output_partitions[ 1 ] . len( ) ) ;
1713- assert_eq ! ( 30 , output_partitions[ 2 ] . len( ) ) ;
1714- assert_eq ! ( 30 , output_partitions[ 3 ] . len( ) ) ;
1715- assert_eq ! ( 30 , output_partitions[ 4 ] . len( ) ) ;
1795+ for partition in output_partitions {
1796+ assert_eq ! ( 1 , partition. len( ) ) ;
1797+ assert_eq ! ( total_rows_per_partition, partition[ 0 ] . num_rows( ) ) ;
1798+ }
17161799
17171800 Ok ( ( ) )
17181801 }
@@ -1950,14 +2033,13 @@ mod tests {
19502033 } ) ;
19512034 let batches_with_drop = crate :: common:: collect ( output_stream1) . await . unwrap ( ) ;
19522035
1953- fn sort ( batch : Vec < RecordBatch > ) -> Vec < RecordBatch > {
1954- batch
1955- . into_iter ( )
1956- . sorted_by_key ( |b| format ! ( "{b:?}" ) )
1957- . collect ( )
1958- }
1959-
1960- assert_eq ! ( sort( batches_without_drop) , sort( batches_with_drop) ) ;
2036+ let items_vec_with_drop = str_batches_to_vec ( & batches_with_drop) ;
2037+ let items_set_with_drop: HashSet < & str > =
2038+ items_vec_with_drop. iter ( ) . copied ( ) . collect ( ) ;
2039+ assert_eq ! (
2040+ items_set_with_drop. symmetric_difference( & items_set) . count( ) ,
2041+ 0
2042+ ) ;
19612043 }
19622044
19632045 fn str_batches_to_vec ( batches : & [ RecordBatch ] ) -> Vec < & str > {
0 commit comments