Skip to content

Commit f9a3ebb

Browse files
Rahamim, Beneyala
authored and committed
DATAFU-176 Add collectNumberOrderedElements
Signed-off-by: Eyal Allweil <eyal@apache.org>
1 parent 601ab12 commit f9a3ebb

File tree

2 files changed

+84
-3
lines changed

2 files changed

+84
-3
lines changed

datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ import org.apache.spark.sql.Column
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, DeclarativeAggregate, ImperativeAggregate}
25-
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal}
25+
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, Concat, CreateArray, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal, Size, Slice, SortArray}
2626
import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
27-
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
27+
import org.apache.spark.sql.functions.lit
28+
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType}
2829

2930
import scala.collection.generic.Growable
3031
import scala.collection.mutable
@@ -39,10 +40,16 @@ In order to support Spark 3.1.x as well as Spark 3.2.0 and up, the methods withN
3940
object SparkOverwriteUDAFs {
4041
def minValueByKey(key: Column, value: Column): Column =
4142
Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))
43+
4244
def maxValueByKey(key: Column, value: Column): Column =
4345
Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))
46+
4447
def collectLimitedList(e: Column, maxSize: Int): Column =
4548
Column(CollectLimitedList(e.expr, howMuchToTake = maxSize).toAggregateExpression(false))
49+
50+
def collectNumberOrderedElements(col: Column, howManyToTake: Int, ascending: Boolean = false) =
51+
Column(CollectNumberOrderedElements(col.expr, lit(howManyToTake).expr, lit(ascending).expr).toAggregateExpression(false))
52+
4653
}
4754

4855
case class MinValueByKey(child1: Expression, child2: Expression)
@@ -86,7 +93,7 @@ abstract class ExtramumValueByKey(
8693
private lazy val data = AttributeReference("data", child2.dataType)()
8794

8895
override lazy val aggBufferAttributes
89-
: Seq[AttributeReference] = minmax :: data :: Nil
96+
: Seq[AttributeReference] = minmax :: data :: Nil
9097

9198
override lazy val initialValues: Seq[Expression] = Seq(
9299
Literal.create(null, child1.dataType),
@@ -174,3 +181,48 @@ abstract class LimitedCollect[T <: Growable[Any] with Iterable[Any]](howMuchToTa
174181
}
175182
}
176183
}
184+
185+
case class CollectNumberOrderedElements(child: Expression, howManyToTake: Expression, ascending: Expression) extends DeclarativeAggregate with ExpectsInputTypes {
186+
187+
override def children: Seq[Expression] = Seq(child)
188+
189+
override def nullable: Boolean = true
190+
191+
// Return data type.
192+
override def dataType: DataType = ArrayType(child.dataType, containsNull = false)
193+
194+
// Expected input data type.
195+
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
196+
197+
override def checkInputDataTypes(): TypeCheckResult =
198+
TypeUtils.checkForOrderingExpr(child.dataType, "function TakeFirstValues")
199+
200+
private lazy val data = AttributeReference("data", ArrayType(child.dataType, containsNull = false))()
201+
202+
override lazy val aggBufferAttributes: Seq[AttributeReference] = data :: Nil
203+
204+
override lazy val initialValues: Seq[Expression] = Seq(
205+
Literal.create(Array(), ArrayType(child.dataType, containsNull = false))
206+
)
207+
208+
// Change to array_append after Spark 3.4.0
209+
override lazy val updateExpressions: Seq[Expression] = sortAndSliceArray(data, CreateArray(Seq(child)))
210+
211+
override lazy val mergeExpressions: Seq[Expression] = sortAndSliceArray(data.right, data.left)
212+
213+
private def sortAndSliceArray(firstArray: Expression, secondArray: Expression) = {
214+
val unifiedArray = Concat(Seq(firstArray, secondArray))
215+
Seq(
216+
If(GreaterThan(Size(unifiedArray), howManyToTake),
217+
Slice(SortArray(unifiedArray, ascending), Literal(1), howManyToTake),
218+
unifiedArray))
219+
}
220+
221+
override lazy val evaluateExpression: Expression = {
222+
SortArray(data, ascending)
223+
}
224+
225+
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
226+
copy(child = newChildren.head)
227+
}
228+
}

datafu-spark/src/test/scala/datafu/spark/TestSparkUDAFs.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,33 @@ class UdafTests extends FunSuite with DataFrameSuiteBase {
145145
Assert.assertEquals(0, rows_different.count())
146146

147147
}
148+
149+
test("test_collectNumberOrderedElements") {
150+
val nums = Seq(1,3,2,5,4,7,6,9,8)
151+
val rows = nums.flatMap(x => (1 to x).map(n => (x, "str" + n))).toDF("num", "str")
152+
153+
import org.apache.spark.sql.functions._
154+
155+
val result = rows.groupBy("num").agg(SparkOverwriteUDAFs.collectNumberOrderedElements(col("str"), 4, ascending = false).as("list"))
156+
157+
val schema = StructType(List(
158+
StructField("num", IntegerType, nullable = false),
159+
StructField("list", ArrayType(StringType, containsNull = false)
160+
)))
161+
162+
val expected = sqlContext.createDataFrame(
163+
sc.parallelize(Seq(
164+
Row(1, Seq("str1")),
165+
Row(2, Seq("str2", "str1")),
166+
Row(3, Seq("str3", "str2", "str1")),
167+
Row(4, Seq("str4", "str3", "str2", "str1")),
168+
Row(5, Seq("str5", "str4", "str3", "str2")),
169+
Row(6, Seq("str6", "str5", "str4", "str3")),
170+
Row(7, Seq("str7", "str6", "str5", "str4")),
171+
Row(8, Seq("str8", "str7", "str6", "str5")),
172+
Row(9, Seq("str9", "str8", "str7", "str6")),
173+
)), schema)
174+
175+
assertDataFrameNoOrderEquals(expected, result)
176+
}
148177
}

0 commit comments

Comments
 (0)