@@ -22,9 +22,10 @@ import org.apache.spark.sql.Column
2222import org .apache .spark .sql .catalyst .InternalRow
2323import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
2424import org .apache .spark .sql .catalyst .expressions .aggregate .{Collect , DeclarativeAggregate , ImperativeAggregate }
25- import org .apache .spark .sql .catalyst .expressions .{AttributeReference , BinaryComparison , ExpectsInputTypes , Expression , GreaterThan , If , IsNull , LessThan , Literal }
25+ import org .apache .spark .sql .catalyst .expressions .{AttributeReference , BinaryComparison , Concat , CreateArray , ExpectsInputTypes , Expression , GreaterThan , If , IsNull , LessThan , Literal , Size , Slice , SortArray }
2626import org .apache .spark .sql .catalyst .util .{GenericArrayData , TypeUtils }
27- import org .apache .spark .sql .types .{AbstractDataType , AnyDataType , DataType }
27+ import org .apache .spark .sql .functions .lit
28+ import org .apache .spark .sql .types .{AbstractDataType , AnyDataType , ArrayType , DataType }
2829
2930import scala .collection .generic .Growable
3031import scala .collection .mutable
@@ -39,10 +40,16 @@ In order to support Spark 3.1.x as well as Spark 3.2.0 and up, the methods withN
/**
 * Column-level entry points for the custom Catalyst aggregates defined in this file.
 *
 * Each method wraps the corresponding aggregate expression into a [[Column]] by calling
 * `toAggregateExpression(false)` (i.e. as a non-distinct aggregate).
 */
object SparkOverwriteUDAFs {

  /**
   * Builds a [[MinValueByKey]] aggregate column from the given key and value columns.
   *
   * @param key   column passed as the first child of the aggregate
   * @param value column passed as the second child of the aggregate
   */
  def minValueByKey(key: Column, value: Column): Column =
    Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))

  /**
   * Builds a [[MaxValueByKey]] aggregate column from the given key and value columns.
   *
   * @param key   column passed as the first child of the aggregate
   * @param value column passed as the second child of the aggregate
   */
  def maxValueByKey(key: Column, value: Column): Column =
    Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))

  /**
   * Builds a [[CollectLimitedList]] aggregate column.
   *
   * @param e       column whose values are collected
   * @param maxSize forwarded to the aggregate as `howMuchToTake`
   */
  def collectLimitedList(e: Column, maxSize: Int): Column =
    Column(CollectLimitedList(e.expr, howMuchToTake = maxSize).toAggregateExpression(false))

  /**
   * Builds a [[CollectNumberOrderedElements]] aggregate column.
   *
   * @param col           column whose values are collected; its type must be orderable
   * @param howManyToTake maximum number of elements to keep, passed to the aggregate as a literal
   * @param ascending     sort direction passed to the aggregate as a literal; defaults to `false`
   * @return the aggregate wrapped as a [[Column]]
   */
  def collectNumberOrderedElements(col: Column, howManyToTake: Int, ascending: Boolean = false): Column =
    Column(
      CollectNumberOrderedElements(col.expr, lit(howManyToTake).expr, lit(ascending).expr)
        .toAggregateExpression(false))

}
4754
4855case class MinValueByKey (child1 : Expression , child2 : Expression )
@@ -86,7 +93,7 @@ abstract class ExtramumValueByKey(
8693 private lazy val data = AttributeReference (" data" , child2.dataType)()
8794
8895 override lazy val aggBufferAttributes
89- : Seq [AttributeReference ] = minmax :: data :: Nil
96+ : Seq [AttributeReference ] = minmax :: data :: Nil
9097
9198 override lazy val initialValues : Seq [Expression ] = Seq (
9299 Literal .create(null , child1.dataType),
@@ -174,3 +181,48 @@ abstract class LimitedCollect[T <: Growable[Any] with Iterable[Any]](howMuchToTa
174181 }
175182 }
176183}
184+
/**
 * Declarative aggregate that collects, per group, at most `howManyToTake` values of `child`
 * and returns them as a sorted array.
 *
 * The buffer is a single array. Whenever adding new elements makes it larger than
 * `howManyToTake`, it is sorted (direction given by `ascending`) and truncated to its first
 * `howManyToTake` elements. With `ascending = true` the smallest values are kept; otherwise
 * the largest ones. Null inputs are skipped so that the buffer honours its declared
 * `containsNull = false` element nullability.
 *
 * @param child         expression producing the values to collect; must have an orderable type
 * @param howManyToTake expression giving the maximum number of elements to keep
 * @param ascending     boolean expression selecting the sort direction
 */
case class CollectNumberOrderedElements(child: Expression, howManyToTake: Expression, ascending: Expression)
  extends DeclarativeAggregate with ExpectsInputTypes {

  // Only `child` is exposed as a tree child; `howManyToTake` and `ascending` are used as given.
  override def children: Seq[Expression] = Seq(child)

  override def nullable: Boolean = true

  // Return data type.
  override def dataType: DataType = ArrayType(child.dataType, containsNull = false)

  // Expected input data type.
  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  // The collected values are sorted, so the input type has to support ordering.
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function CollectNumberOrderedElements")

  // Aggregation buffer: the elements retained so far (same type as the result).
  private lazy val data = AttributeReference("data", dataType)()

  override lazy val aggBufferAttributes: Seq[AttributeReference] = data :: Nil

  // Start every group from an empty array.
  override lazy val initialValues: Seq[Expression] = Seq(
    Literal.create(Array(), dataType)
  )

  // Change to array_append after Spark 3.4.0
  // Null inputs leave the buffer untouched: appending them would put nulls into an array
  // whose type is declared with containsNull = false.
  override lazy val updateExpressions: Seq[Expression] = Seq(
    If(IsNull(child), data, sortAndSliceArray(data, CreateArray(Seq(child))))
  )

  override lazy val mergeExpressions: Seq[Expression] =
    Seq(sortAndSliceArray(data.right, data.left))

  /**
   * Concatenates the two arrays and, only if the result holds more than `howManyToTake`
   * elements, sorts it and keeps the first `howManyToTake` of them. Smaller results are
   * returned unsorted; the final ordering is applied in `evaluateExpression`.
   */
  private def sortAndSliceArray(firstArray: Expression, secondArray: Expression): Expression = {
    val unifiedArray = Concat(Seq(firstArray, secondArray))
    If(
      GreaterThan(Size(unifiedArray), howManyToTake),
      Slice(SortArray(unifiedArray, ascending), Literal(1), howManyToTake),
      unifiedArray)
  }

  // The buffer may still be unsorted (when it never exceeded the limit), so sort on output.
  override lazy val evaluateExpression: Expression = SortArray(data, ascending)

  // NOTE(review): declared without `override`, consistent with the file's note about
  // supporting both Spark 3.1.x and Spark 3.2.0+ - presumably because this method is not
  // present in the 3.1.x API; confirm before adding the modifier.
  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
    copy(child = newChildren.head)
}
0 commit comments