diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitExec.scala index a3ee421e59a5..55d5a8628db8 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitExec.scala @@ -19,6 +19,7 @@ package org.apache.gluten.execution import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.columnarbatch.VeloxColumnarBatches +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.vectorized.ColumnarBatch @@ -96,6 +97,28 @@ case class ColumnarCollectLimitExec( } } + override def executeCollect(): Array[InternalRow] = { + val inputBatches = + if (limit >= 0) { + child.executeColumnar() + } else { + executeColumnar() + } + val rowsRdd = inputBatches.mapPartitions { + it => + val rows = VeloxColumnarToRowExec.toRowIterator(it) + rows.map(_.copy()) + } + if (limit >= 0) { + val toTake = math.max(0, offset) + limit + val taken = rowsRdd.take(toTake) + if (offset > 0) taken.drop(offset) else taken + } else { + val all = rowsRdd.collect() + if (offset > 0) all.drop(offset) else all + } + } + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala index 3d3f4445c5bb..35e41a4e5384 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala @@ -60,4 +60,13 @@ abstract class ColumnarToRowExecBase(child: SparkPlan) override def doExecute(): RDD[InternalRow] = { doExecuteInternal() } + + override def executeCollect(): Array[InternalRow] = { + child match { + case l: ColumnarCollectLimitBaseExec => + l.executeCollect() + case _ => + super.executeCollect() + } + } }