From 8998c6b61ff6a4146745c2881014a59c950fd898 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 21 Apr 2026 15:29:16 +0800 Subject: [PATCH 01/12] Add initial GpuArrayAggregate with SUM decomposition Implements ArrayAggregate on the GPU for lambdas decomposable as (acc, x) -> acc + g(x) with an identity finish. Other shapes fall back to the CPU. - ArrayAggregateDecomposer: match merge body against Add(acc, g), unwrap Cast on the acc side, validate finish is identity - GpuArrayAggregate: evaluate g(x) via the existing GpuArrayTransformBase explode path, then listReduce + combine with zero. Uses NullPolicy.INCLUDE so null elements poison the sum, matching Spark's iterative `acc + null = null` semantics. Empty (non-null) lists are substituted with op's identity before the add-zero step; null lists stay null and propagate. - Decimal identity scalar is bound to the column's DType (via Scalar.fromDecimal(BigInteger, DType)) so ifElse / add don't trip on DECIMAL32-vs-DECIMAL128 width mismatches. - Unit tests for the decomposer and integration tests covering the client pattern, null / empty arrays, non-zero init, outer-column refs, struct-field access, long overflow, decimal sum, and fallback cases. Addresses part of https://github.com/NVIDIA/spark-rapids/issues/8532. A follow-up refactor will introduce a normalize pass and AggOp trait to support PRODUCT / MIN / MAX / AND / OR and Cast stripping. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../python/higher_order_functions_test.py | 184 ++++++++++++- .../nvidia/spark/rapids/GpuOverrides.scala | 22 ++ .../spark/rapids/higherOrderFunctions.scala | 252 +++++++++++++++++- .../ArrayAggregateDecomposerSuite.scala | 183 +++++++++++++ tools/generated_files/352/operatorsScore.csv | 1 + tools/generated_files/352/supportedExprs.csv | 5 + 6 files changed, 641 insertions(+), 6 deletions(-) create mode 100644 tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 23d61793b46..3a7f72b6624 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from asserts import assert_gpu_and_cpu_are_equal_collect -from marks import ignore_order +import pytest + +from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect +from data_gen import * +from marks import allow_non_gpu, disable_ansi_mode, ignore_order @ignore_order(local=True) @@ -29,3 +32,178 @@ def do_project(spark): return df.selectExpr( "transform(c, (v, i) -> named_struct('x', c[i].x, 'y', c[i].y)) AS t") assert_gpu_and_cpu_are_equal_collect(do_project, conf=confs) + + +# --- ArrayAggregate tests --- +# +# Covers the decomposable SUM path: lambdas of the form `(acc, x) -> acc + g(x)` with an +# identity finish. Non-decomposable shapes must fall back to CPU. + +# Simple: sum elements of an array with zero = 0. Accumulator type is promoted to long +# via the CAST in g, matching the zero's LongType. +@disable_ansi_mode +def test_array_aggregate_sum_int_to_long(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=15)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + CAST(x as BIGINT)) as sum')) + + +# Count-if pattern: sum of CASE WHEN predicate THEN 1 ELSE 0 END. +# This is the structural twin of the client's real workload. +@disable_ansi_mode +def test_array_aggregate_count_if_int(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=15)).selectExpr( + 'aggregate(a, 0, (acc, x) -> acc + CASE WHEN x > 0 THEN 1 ELSE 0 END) as pos_cnt', + 'aggregate(a, 0L, (acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 1 ELSE 0 END as BIGINT)) as null_cnt')) + + +# Client's actual pattern (simplified to two fields): filter + aggregate with split / GetArrayItem / IN. +# The string data is synthesized with 4-space separators so each element has enough columns. +@disable_ansi_mode +def test_array_aggregate_client_pattern(): + # Generate strings like "aa bb cc dd ee" so the split yields >2 pieces. + field_gen = StringGen('[a-z]{2}') + # Build strings via concat_ws in SQL; use a simple array of ~5 strings. + def do_it(spark): + df = unary_op_df(spark, ArrayGen(field_gen, max_length=5)) + return df.selectExpr(""" + aggregate( + filter(transform(a, x -> concat_ws(' ', x, x, x, x, x)), z -> z != ''), + 0L, + (acc, z) -> acc + CAST(CASE WHEN ( + size(split(z, ' ', -1)) > 2 + AND split(z, ' ', -1)[2] IN ('aa', 'bb') + AND NOT split(z, ' ', -1)[1] IN ('xx', 'yy') + ) THEN 1 ELSE 0 END as BIGINT), + id -> id + ) as client_cnt""") + assert_gpu_and_cpu_are_equal_collect(do_it) + + +# Non-zero init: result should include the init. +@disable_ansi_mode +def test_array_aggregate_non_zero_init(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=10)).selectExpr( + 'aggregate(a, 100L, (acc, x) -> acc + CAST(x as BIGINT)) as sum_with_init')) + + +# Null and empty arrays. Spark semantics: null array -> null, empty array -> finish(init) = init. +@disable_ansi_mode +def test_array_aggregate_null_array(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, all_null=True)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + CAST(x as BIGINT)) as should_be_null')) + + +@disable_ansi_mode +def test_array_aggregate_empty_array(): + def do_it(spark): + # Array column with some empty arrays interspersed. + return spark.createDataFrame( + [([1, 2, 3],), ([],), ([7],), ([],)], + 'a array').selectExpr( + 'aggregate(a, 42L, (acc, x) -> acc + CAST(x as BIGINT)) as sum_with_empty') + assert_gpu_and_cpu_are_equal_collect(do_it) + + +# Non-decomposable lambda must fall back to CPU. `acc - x` is not associative / not in whitelist. +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_subtract_falls_back(): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc - CAST(x as BIGINT)) as diff'), + 'ArrayAggregate') + + +# Non-identity finish must fall back. +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_non_identity_finish_falls_back(): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + CAST(x as BIGINT), acc -> acc * 2) as doubled'), + 'ArrayAggregate') + + +# g that references the accumulator must fall back. +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_g_references_acc_falls_back(): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + acc * CAST(x as BIGINT)) as recur'), + 'ArrayAggregate') + + +# Multiplicative accumulator (not in the SUM whitelist) must fall back. +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_product_falls_back(): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( + 'aggregate(a, 1L, (acc, x) -> acc * CAST(x as BIGINT)) as prod'), + 'ArrayAggregate') + + +# Lambda body references an outer attribute ("b") — exercises boundIntermediate plumbing. +@disable_ansi_mode +def test_array_aggregate_lambda_refs_outer_column(): + def do_it(spark): + return two_col_df(spark, ArrayGen(int_gen, max_length=10), int_gen).selectExpr( + # g(x) = (x + b) — b is a closed-over outer column, not the acc. + 'aggregate(a, 0L, (acc, x) -> acc + CAST(x + b as BIGINT)) as sum_with_outer') + assert_gpu_and_cpu_are_equal_collect(do_it) + + +# zero is an outer column, not a literal. +@disable_ansi_mode +def test_array_aggregate_zero_is_outer_column(): + def do_it(spark): + return two_col_df(spark, ArrayGen(int_gen, max_length=10), long_gen).selectExpr( + 'aggregate(a, b, (acc, x) -> acc + CAST(x as BIGINT)) as sum_from_col') + assert_gpu_and_cpu_are_equal_collect(do_it) + + +# array: accumulate over a struct field. +@disable_ansi_mode +def test_array_aggregate_over_struct_field(): + def do_it(spark): + elem_gen = StructGen([['i', int_gen]], nullable=False) + return unary_op_df(spark, ArrayGen(elem_gen, max_length=10)).selectExpr( + 'aggregate(a, 0L, (acc, s) -> acc + CAST(s.i as BIGINT)) as sum_field') + assert_gpu_and_cpu_are_equal_collect(do_it) + + +# Deeper g body without acc references (x * 2 + 1). +@disable_ansi_mode +def test_array_aggregate_deeper_g_body(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=10)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + CAST(x * 2 + 1 as BIGINT)) as sum_poly')) + + +# Long-overflow wrap-around: in non-ANSI mode both Spark SUM and cudf SUM wrap silently. +@disable_ansi_mode +def test_array_aggregate_long_overflow_wraps(): + def do_it(spark): + big = LongGen(min_val=9223372036854775000, max_val=9223372036854775700, nullable=False) + return unary_op_df(spark, ArrayGen(big, min_length=5, max_length=15)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + x) as wrapped_sum') + assert_gpu_and_cpu_are_equal_collect(do_it) + + +# Decimal SUM: Spark's ArrayAggregate requires merge.dataType == zero.dataType exactly. For +# DECIMAL(10,2) + DECIMAL(10,2) the result is DECIMAL(11,2), which fails analysis against a +# DECIMAL(10,2) zero. The working pattern is to widen zero to DECIMAL(38,2) (Spark's cap) +# and Cast the element to match, so `acc + Cast(x, DECIMAL(38,2))` stays at DECIMAL(38,2). +@disable_ansi_mode +def test_array_aggregate_decimal_sum(): + decimal_gen = DecimalGen(precision=10, scale=2) + def do_it(spark): + return unary_op_df(spark, ArrayGen(decimal_gen, max_length=8)).selectExpr( + 'aggregate(a, CAST(0 as DECIMAL(38,2)), ' + '(acc, x) -> acc + CAST(x as DECIMAL(38,2))) as dec_sum') + assert_gpu_and_cpu_are_equal_collect(do_it) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 8790334e48a..5c6e40fc457 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2962,6 +2962,28 @@ object GpuOverrides extends Logging { ) } }), + expr[ArrayAggregate]( + "Aggregate elements in an array using an accumulator function and finishing " + + "transformation. Currently only lambdas of the form (acc, x) -> acc + g(x) with an " + + "identity finish are executed on the GPU; other shapes fall back to CPU.", + ExprChecks.projectOnly( + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, + TypeSig.all, + Seq( + ParamCheck("argument", + TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + + TypeSig.STRUCT), + TypeSig.ARRAY.nested(TypeSig.all)), + ParamCheck("zero", + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, + TypeSig.all), + ParamCheck("merge", + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, + TypeSig.all), + ParamCheck("finish", + TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, + TypeSig.all))), + (in, conf, p, r) => new GpuArrayAggregateMeta(in, conf, p, r)), // TODO: fix the signature https://github.com/NVIDIA/spark-rapids/issues/5327 expr[ArraysZip]( "Returns a merged array of structs in which the N-th struct contains" + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 22e21c3b125..0f59ccf9b0d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -26,9 +26,9 @@ import com.nvidia.spark.rapids.jni.GpuMapZipWithUtils import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, Expression, ExprId, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Add, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, Cast, Expression, ExprId, LambdaFunction, NamedExpression, NamedLambdaVariable} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BooleanType, DataType, MapType, Metadata, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch /** @@ -222,7 +222,7 @@ trait GpuArrayTransformBase extends GpuSimpleHigherOrderFunction { boundIntermediate.map(_.dataType) ++ lambdaFunction.arguments.map(_.dataType) } - private[this] def makeElementProjectBatch( + protected def makeElementProjectBatch( inputBatch: ColumnarBatch, argColumn: GpuColumnVector): ColumnarBatch = { assert(argColumn.getBase.getType.equals(DType.LIST)) @@ -895,3 +895,249 @@ case class GpuMapFilter(argument: Expression, } } } + + +/** + * Decomposes a Spark ArrayAggregate's merge lambda of shape `(acc, x) -> op(acc, g(x))` into + * a map-reduce form executable with cuDF segmented-reduction APIs. Currently supports SUM. + */ +object ArrayAggregateDecomposer { + + sealed trait AggOp + case object SumOp extends AggOp + + /** + * @param op the segmented reduction aggregation operator + * @param gChildIndex 0 if `g` is the left child of the merge body's binary op, 1 if right + * @param accVarExprId the accumulator NamedLambdaVariable's exprId + * @param elemVar the element NamedLambdaVariable (used to build the g lambda) + */ + case class Decomposition( + op: AggOp, + gChildIndex: Int, + accVarExprId: ExprId, + elemVar: NamedLambdaVariable) + + def decompose(merge: Expression, finish: Expression): Option[Decomposition] = { + val mergeLambda = merge match { + case lf: LambdaFunction => lf + case _ => return None + } + val (accVar, elemVar) = mergeLambda.arguments match { + case Seq(a: NamedLambdaVariable, e: NamedLambdaVariable) => (a, e) + case _ => return None + } + if (!isFinishIdentity(finish)) return None + + mergeLambda.function match { + case add: Add => + val accId = accVar.exprId + if (isAccRef(add.left, accId) && !containsAccRef(add.right, accId)) { + Some(Decomposition(SumOp, 1, accId, elemVar)) + } else if (isAccRef(add.right, accId) && !containsAccRef(add.left, accId)) { + Some(Decomposition(SumOp, 0, accId, elemVar)) + } else { + None + } + case _ => None + } + } + + private def isFinishIdentity(finish: Expression): Boolean = finish match { + case LambdaFunction(body, Seq(accVar: NamedLambdaVariable), _) => + isAccRef(body, accVar.exprId) + case _ => false + } + + private def isAccRef(e: Expression, id: ExprId): Boolean = e match { + case v: NamedLambdaVariable => v.exprId == id + case c: Cast => isAccRef(c.child, id) + case _ => false + } + + private def containsAccRef(e: Expression, id: ExprId): Boolean = e.exists { + case v: NamedLambdaVariable if v.exprId == id => true + case _ => false + } +} + + +/** + * GPU implementation of ArrayAggregate restricted to lambdas decomposable via + * ArrayAggregateDecomposer. Runtime steps: + * 1. Evaluate g(x) over the array children (reusing GpuArrayTransformBase's explode path). + * 2. Rewrap as list with the original offsets and validity. + * 3. cuDF segmented reduce. + * 4. Replace null reduction results (from empty lists) with op's identity. + * 5. Combine with zero: result = zero op filled. + * 6. Restore null for rows where the input array was null. + */ +case class GpuArrayAggregate( + argument: Expression, + zero: Expression, + function: Expression, + op: ArrayAggregateDecomposer.AggOp, + isBound: Boolean = false, + boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayTransformBase { + + override def dataType: DataType = zero.dataType + + override def nullable: Boolean = argument.nullable + + override def prettyName: String = "array_aggregate" + + // Include zero as a child so analyzer / optimizer passes can see it. + override def children: Seq[Expression] = argument :: zero :: function :: Nil + + override def bind(input: AttributeSeq): GpuExpression = { + val (boundFunc, boundArg, boundInter) = bindLambdaFunc(input) + val boundZero = GpuBindReferences.bindGpuReferenceNoMetrics(zero, input) + GpuArrayAggregate(boundArg, boundZero, boundFunc, op, isBound = true, boundInter) + } + + // We override columnarEval entirely; the base class's template isn't a fit because + // the lambda output still needs a segmented reduction plus combine-with-zero. + override protected def transformListColumnView( + lambdaTransformedCV: cudf.ColumnView, + arg: cudf.ColumnView): GpuColumnVector = { + throw new IllegalStateException( + "GpuArrayAggregate overrides columnarEval; transformListColumnView is unused") + } + + private def cudfAgg: cudf.SegmentedReductionAggregation = op match { + case ArrayAggregateDecomposer.SumOp => cudf.SegmentedReductionAggregation.sum() + } + + private def identityScalar(outDType: DType): cudf.Scalar = op match { + case ArrayAggregateDecomposer.SumOp => + dataType match { + case _: ByteType => cudf.Scalar.fromByte(0.toByte) + case _: ShortType => cudf.Scalar.fromShort(0.toShort) + case _: IntegerType => cudf.Scalar.fromInt(0) + case _: LongType => cudf.Scalar.fromLong(0L) + case _: FloatType => cudf.Scalar.fromFloat(0.0f) + case _: DoubleType => cudf.Scalar.fromDouble(0.0) + case _: DecimalType => + // BigDecimal-based fromDecimal picks DECIMAL32/64/128 from the value's precision, + // which does not match the reduced column's fixed width. Bind to the column's + // DType explicitly so ifElse and add don't see a width mismatch. + cudf.Scalar.fromDecimal(java.math.BigInteger.ZERO, outDType) + case other => + throw new IllegalStateException(s"SUM identity not defined for $other") + } + } + + private def combineWithZero( + filled: cudf.ColumnVector, + zeroCv: cudf.ColumnView, + outDType: DType): cudf.ColumnVector = op match { + case ArrayAggregateDecomposer.SumOp => filled.add(zeroCv, outDType) + } + + /** + * Boolean mask: true iff the list is empty *and not null*. Used to substitute op's identity + * only for the empty-list rows, while letting null lists and null elements propagate + * through the subsequent combine-with-zero step. + */ + private def emptyNotNullMask(listCol: cudf.ColumnView): cudf.ColumnVector = { + withResource(listCol.countElements()) { counts => + withResource(cudf.Scalar.fromInt(0)) { zeroInt => + withResource(counts.equalTo(zeroInt)) { isEmpty => + if (argument.nullable) { + withResource(listCol.isNotNull) { isNotNull => + isEmpty.and(isNotNull) + } + } else { + isEmpty.incRefCount() + } + } + } + } + } + + override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + withResource(argument.asInstanceOf[GpuExpression].columnarEval(batch)) { arg => + val transformedData = withResource(makeElementProjectBatch(batch, arg)) { cb => + function.asInstanceOf[GpuExpression].columnarEval(cb) + } + withResource(transformedData) { transformedData => + val listOfGView = GpuListUtils.replaceListDataColumnAsView( + arg.getBase, transformedData.getBase) + withResource(listOfGView) { listOfGView => + val outDType = GpuColumnVector.getNonNestedRapidsType(dataType) + // INCLUDE nulls: Spark evaluates `acc + g(x)` and null poisons the accumulator, so + // any null element in the reduced-over list produces null. cuDF also returns null + // for null lists and empty lists. We substitute identity only for the empty-but- + // not-null case; null lists and null-poisoned sums stay null and fall through the + // add(zero) step preserving null. + withResource(listOfGView.listReduce(cudfAgg, cudf.NullPolicy.INCLUDE, outDType)) { + reduced => + withResource(emptyNotNullMask(arg.getBase)) { isEmptyNotNull => + withResource(identityScalar(outDType)) { idScalar => + withResource(isEmptyNotNull.ifElse(idScalar, reduced)) { adjusted => + withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => + GpuColumnVector.from( + combineWithZero(adjusted, zeroCv.getBase, outDType), dataType) + } + } + } + } + } + } + } + } + } +} + + +/** + * Expression-level meta for Spark's ArrayAggregate. Only accepts lambdas that + * ArrayAggregateDecomposer can decompose into (op, g); otherwise falls back to CPU. + */ +class GpuArrayAggregateMeta( + expr: ArrayAggregate, + conf: RapidsConf, + parent: Option[RapidsMeta[_, _, _]], + rule: DataFromReplacementRule) + extends ExprMeta[ArrayAggregate](expr, conf, parent, rule) { + + private var decomposition: Option[ArrayAggregateDecomposer.Decomposition] = None + + override def tagExprForGpu(): Unit = { + val d = ArrayAggregateDecomposer.decompose(expr.merge, expr.finish) + if (d.isEmpty) { + willNotWorkOnGpu( + "ArrayAggregate only supports lambdas of the form (acc, x) -> acc + g(x) with " + + "an identity finish lambda (SUM only for now). Other shapes are not supported.") + return + } + // g's output type must equal the accumulator/zero type so the segmented reduce output + // matches the Spark-expected result type directly. + val body = expr.merge.asInstanceOf[LambdaFunction].function.asInstanceOf[Add] + val gType = body.children(d.get.gChildIndex).dataType + if (!DataType.equalsStructurally(gType, expr.zero.dataType, ignoreNullability = true)) { + willNotWorkOnGpu( + s"g(x) output type ($gType) does not match accumulator/zero type " + + s"(${expr.zero.dataType})") + return + } + decomposition = d + } + + override def convertToGpuImpl(): GpuExpression = { + val d = decomposition.getOrElse( + throw new IllegalStateException("tagExprForGpu must succeed before convertToGpu")) + + val argGpu = childExprs.head.convertToGpu() + val zeroGpu = childExprs(1).convertToGpu() + // childExprs(2) is the merge lambda meta; its first child is the Add body meta, whose + // gChildIndex-th child is the g sub-expression we want on GPU. + val bodyMeta = childExprs(2).childExprs.head + val gGpu = bodyMeta.childExprs(d.gChildIndex).convertToGpu() + val elemVarGpu = GpuNamedLambdaVariable( + d.elemVar.name, d.elemVar.dataType, d.elemVar.nullable, d.elemVar.exprId) + val gLambda = GpuLambdaFunction(gGpu, Seq(elemVarGpu)) + + GpuArrayAggregate(argGpu, zeroGpu, gLambda, d.op) + } +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala new file mode 100644 index 00000000000..85089167432 --- /dev/null +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids + +import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Expression, LambdaFunction, + Literal, Multiply, NamedExpression, NamedLambdaVariable, Subtract} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} + +// Extends GpuUnitTests so SQLConf.get is available for the default evalMode/failOnError +// parameter on Add/Subtract/Multiply (the field name differs across Spark versions; letting +// Spark apply its own default keeps this test shim-agnostic). +class ArrayAggregateDecomposerSuite extends GpuUnitTests { + import ArrayAggregateDecomposer._ + + // --- helpers ----------------------------------------------------------- + + private def lv(name: String, dt: DataType = IntegerType): NamedLambdaVariable = + NamedLambdaVariable(name, dt, nullable = true, exprId = NamedExpression.newExprId) + + private def merge( + body: Expression, + acc: NamedLambdaVariable, + x: NamedLambdaVariable): LambdaFunction = + LambdaFunction(body, Seq(acc, x)) + + private def identityFinish(acc: NamedLambdaVariable): LambdaFunction = + LambdaFunction(acc, Seq(acc)) + + private def plus(l: Expression, r: Expression): Add = Add(l, r) + private def minus(l: Expression, r: Expression): Subtract = Subtract(l, r) + private def times(l: Expression, r: Expression): Multiply = Multiply(l, r) + + // --- positive cases ---------------------------------------------------- + + test("Add(acc, x) decomposes to SUM with gChildIndex=1") { + val acc = lv("acc") + val x = lv("x") + val m = merge(plus(acc, x), acc, x) + val d = decompose(m, identityFinish(acc)) + assert(d.isDefined) + assert(d.get.op == SumOp) + assert(d.get.gChildIndex == 1) + assert(d.get.accVarExprId == acc.exprId) + assert(d.get.elemVar.exprId == x.exprId) + } + + test("Add(x, acc) (commuted) decomposes with gChildIndex=0") { + val acc = lv("acc") + val x = lv("x") + val m = merge(plus(x, acc), acc, x) + val d = decompose(m, identityFinish(acc)) + assert(d.isDefined) + assert(d.get.gChildIndex == 0) + } + + test("Add(acc, complex g(x)) where g contains no acc ref decomposes") { + val acc = lv("acc", LongType) + val x = lv("x", IntegerType) + // g = x * 2 + 1 (as Long after Cast) + val g = plus(times(x, Literal(2)), Literal(1)) + val m = merge(plus(acc, Cast(g, LongType)), acc, x) + val d = decompose(m, identityFinish(acc)) + assert(d.isDefined) + assert(d.get.gChildIndex == 1) + } + + test("acc wrapped in a Cast on the acc-side is still accepted") { + val acc = lv("acc", LongType) + val x = lv("x", IntegerType) + // body = Cast(acc, Int) + x + val m = merge(plus(Cast(acc, IntegerType), x), acc, x) + val d = decompose(m, identityFinish(acc)) + assert(d.isDefined) + assert(d.get.gChildIndex == 1) + } + + test("chained Cast on acc side is unwrapped") { + val acc = lv("acc") + val x = lv("x") + // body = Cast(Cast(acc, Long), Int) + x + val accDoubleCast = Cast(Cast(acc, LongType), IntegerType) + val m = merge(plus(accDoubleCast, x), acc, x) + val d = decompose(m, identityFinish(acc)) + assert(d.isDefined) + } + + // --- negative cases ---------------------------------------------------- + + test("Subtract rejected (only Add is SUM)") { + val acc = lv("acc") + val x = lv("x") + val m = merge(minus(acc, x), acc, x) + assert(decompose(m, identityFinish(acc)).isEmpty) + } + + test("Multiply rejected") { + val acc = lv("acc") + val x = lv("x") + val m = merge(times(acc, x), acc, x) + assert(decompose(m, identityFinish(acc)).isEmpty) + } + + test("g that references acc is rejected") { + val acc = lv("acc") + val x = lv("x") + // g = acc * x — references acc + val m = merge(plus(acc, times(acc, x)), acc, x) + assert(decompose(m, identityFinish(acc)).isEmpty) + } + + test("both sides reference acc is rejected") { + val acc = lv("acc") + val x = lv("x") + val m = merge(plus(acc, acc), acc, x) + assert(decompose(m, identityFinish(acc)).isEmpty) + } + + test("neither side is a pure acc ref is rejected") { + val acc = lv("acc") + val x = lv("x") + // body = (acc + 1) + x — left has acc but isn't a naked acc ref + val leftWithPlusOne = plus(acc, Literal(1)) + val m = merge(plus(leftWithPlusOne, x), acc, x) + assert(decompose(m, identityFinish(acc)).isEmpty) + } + + test("non-identity finish rejected") { + val acc = lv("acc") + val x = lv("x") + val m = merge(plus(acc, x), acc, x) + val finishAcc = lv("finishAcc") + val nonIdentityFinish = + LambdaFunction(plus(finishAcc, Literal(1)), Seq(finishAcc)) + assert(decompose(m, nonIdentityFinish).isEmpty) + } + + test("finish referencing a different variable id is rejected") { + val acc = lv("acc") + val x = lv("x") + val m = merge(plus(acc, x), acc, x) + // finish's body references a NamedLambdaVariable with a *different* exprId, + // so it is not the identity over the finish's arg. + val finishAcc = lv("finishAcc") + val someOther = lv("other") // different exprId + val badFinish = LambdaFunction(someOther, Seq(finishAcc)) + assert(decompose(m, badFinish).isEmpty) + } + + test("merge with wrong arg count rejected") { + val acc = lv("acc") + val x = lv("x") + val extra = lv("extra") + val m = LambdaFunction(plus(acc, x), Seq(acc, x, extra)) + assert(decompose(m, identityFinish(acc)).isEmpty) + } + + test("merge that isn't a LambdaFunction at all is rejected") { + val acc = lv("acc") + // Pass something that isn't a lambda. + assert(decompose(plus(Literal(1), Literal(2)), identityFinish(acc)).isEmpty) + } + + test("finish that isn't a LambdaFunction rejected") { + val acc = lv("acc") + val x = lv("x") + val m = merge(plus(acc, x), acc, x) + assert(decompose(m, Literal(0)).isEmpty) + } +} diff --git a/tools/generated_files/352/operatorsScore.csv b/tools/generated_files/352/operatorsScore.csv index 823d05ea694..06525f664cb 100644 --- a/tools/generated_files/352/operatorsScore.csv +++ b/tools/generated_files/352/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/352/supportedExprs.csv b/tools/generated_files/352/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/352/supportedExprs.csv +++ b/tools/generated_files/352/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA From 1a655045cc7998016de8a36b39805782b923979f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 22 Apr 2026 14:52:47 +0800 Subject: [PATCH 02/12] refactor Signed-off-by: Haoyang Li --- .../python/higher_order_functions_test.py | 149 ++++--- .../spark/rapids/higherOrderFunctions.scala | 407 +++++++++++++----- .../ArrayAggregateDecomposerSuite.scala | 243 ++++++----- 3 files changed, 522 insertions(+), 277 deletions(-) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 3a7f72b6624..3d3b5c8afaa 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -36,20 +36,61 @@ def do_project(spark): # --- ArrayAggregate tests --- # -# Covers the decomposable SUM path: lambdas of the form `(acc, x) -> acc + g(x)` with an -# identity finish. Non-decomposable shapes must fall back to CPU. +# The decomposer accepts lambdas of the form `(acc, x) -> op(acc, g(x))` with an identity +# finish, where `op` is one of SUM/PRODUCT/MAX/MIN/ALL/ANY. Other shapes fall back to CPU. + + +# Happy path for each supported numeric op. Product uses a narrow range to keep the test +# output numerically tame (GPU and CPU both wrap consistently, but small numbers make the +# test easier to read when debugging a failure). +@pytest.mark.parametrize('lambda_sql, init_sql, gen_max', [ + ('(acc, x) -> acc + CAST(x as BIGINT)', '0L', 100), + ('(acc, x) -> acc * CAST(x as BIGINT)', '1L', 3), + ('(acc, x) -> greatest(acc, CAST(x as BIGINT))', '-9223372036854775808L', 100), + ('(acc, x) -> least(acc, CAST(x as BIGINT))', '9223372036854775807L', 100), +], ids=['sum', 'product', 'max', 'min']) +@disable_ansi_mode +def test_array_aggregate_numeric_ops(lambda_sql, init_sql, gen_max): + gen = IntegerGen(min_val=-gen_max, max_val=gen_max) + def do_it(spark): + return unary_op_df(spark, ArrayGen(gen, max_length=8)).selectExpr( + f'aggregate(a, {init_sql}, {lambda_sql}) as res') + assert_gpu_and_cpu_are_equal_collect(do_it) -# Simple: sum elements of an array with zero = 0. Accumulator type is promoted to long -# via the CAST in g, matching the zero's LongType. + +# Happy path for the boolean ops. Elements must be non-null because cuDF's segmented ALL/ +# ANY with INCLUDE nulls don't match Spark's AND/OR 3VL for mixed null+bool (specifically, +# `false AND null = false` short-circuit; `true OR null = true`). The tag-time guard falls +# back to CPU when the element type is nullable, so here we use a non-nullable BooleanGen. +@pytest.mark.parametrize('lambda_sql, init_sql', [ + ('(acc, x) -> acc AND x', 'true'), + ('(acc, x) -> acc OR x', 'false'), +], ids=['all', 'any']) @disable_ansi_mode -def test_array_aggregate_sum_int_to_long(): - assert_gpu_and_cpu_are_equal_collect( - lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=15)).selectExpr( - 'aggregate(a, 0L, (acc, x) -> acc + CAST(x as BIGINT)) as sum')) +def test_array_aggregate_boolean_ops(lambda_sql, init_sql): + non_null_bool = BooleanGen(nullable=False) + def do_it(spark): + return unary_op_df(spark, ArrayGen(non_null_bool, max_length=8)).selectExpr( + f'aggregate(a, {init_sql}, {lambda_sql}) as res') + assert_gpu_and_cpu_are_equal_collect(do_it) -# Count-if pattern: sum of CASE WHEN predicate THEN 1 ELSE 0 END. -# This is the structural twin of the client's real workload. +# When array elements may contain nulls, ALL/ANY must fall back to CPU (cuDF's INCLUDE- +# nulls semantics don't match Spark's AND/OR 3VL). +@pytest.mark.parametrize('lambda_sql, init_sql', [ + ('(acc, x) -> acc AND x', 'true'), + ('(acc, x) -> acc OR x', 'false'), +], ids=['all', 'any']) +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_boolean_ops_nullable_elements_fallback(lambda_sql, init_sql): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(boolean_gen, max_length=8)).selectExpr( + f'aggregate(a, {init_sql}, {lambda_sql}) as res'), + 'ArrayAggregate') + + +# Count-if pattern (structural twin of the client's real workload). @disable_ansi_mode def test_array_aggregate_count_if_int(): assert_gpu_and_cpu_are_equal_collect( @@ -58,13 +99,10 @@ def test_array_aggregate_count_if_int(): 'aggregate(a, 0L, (acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 1 ELSE 0 END as BIGINT)) as null_cnt')) -# Client's actual pattern (simplified to two fields): filter + aggregate with split / GetArrayItem / IN. -# The string data is synthesized with 4-space separators so each element has enough columns. +# Client's actual pattern (simplified): filter + aggregate with split / GetArrayItem / IN. @disable_ansi_mode def test_array_aggregate_client_pattern(): - # Generate strings like "aa bb cc dd ee" so the split yields >2 pieces. field_gen = StringGen('[a-z]{2}') - # Build strings via concat_ws in SQL; use a simple array of ~5 strings. def do_it(spark): df = unary_op_df(spark, ArrayGen(field_gen, max_length=5)) return df.selectExpr(""" @@ -89,7 +127,7 @@ def test_array_aggregate_non_zero_init(): 'aggregate(a, 100L, (acc, x) -> acc + CAST(x as BIGINT)) as sum_with_init')) -# Null and empty arrays. Spark semantics: null array -> null, empty array -> finish(init) = init. +# null array -> null, empty array -> finish(init) = init. @disable_ansi_mode def test_array_aggregate_null_array(): assert_gpu_and_cpu_are_equal_collect( @@ -100,7 +138,6 @@ def test_array_aggregate_null_array(): @disable_ansi_mode def test_array_aggregate_empty_array(): def do_it(spark): - # Array column with some empty arrays interspersed. return spark.createDataFrame( [([1, 2, 3],), ([],), ([7],), ([],)], 'a array').selectExpr( @@ -108,52 +145,11 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Non-decomposable lambda must fall back to CPU. `acc - x` is not associative / not in whitelist. -@disable_ansi_mode -@allow_non_gpu('ProjectExec') -def test_array_aggregate_subtract_falls_back(): - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( - 'aggregate(a, 0L, (acc, x) -> acc - CAST(x as BIGINT)) as diff'), - 'ArrayAggregate') - - -# Non-identity finish must fall back. -@disable_ansi_mode -@allow_non_gpu('ProjectExec') -def test_array_aggregate_non_identity_finish_falls_back(): - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( - 'aggregate(a, 0L, (acc, x) -> acc + CAST(x as BIGINT), acc -> acc * 2) as doubled'), - 'ArrayAggregate') - - -# g that references the accumulator must fall back. -@disable_ansi_mode -@allow_non_gpu('ProjectExec') -def test_array_aggregate_g_references_acc_falls_back(): - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( - 'aggregate(a, 0L, (acc, x) -> acc + acc * CAST(x as BIGINT)) as recur'), - 'ArrayAggregate') - - -# Multiplicative accumulator (not in the SUM whitelist) must fall back. -@disable_ansi_mode -@allow_non_gpu('ProjectExec') -def test_array_aggregate_product_falls_back(): - assert_gpu_fallback_collect( - lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( - 'aggregate(a, 1L, (acc, x) -> acc * CAST(x as BIGINT)) as prod'), - 'ArrayAggregate') - - -# Lambda body references an outer attribute ("b") — exercises boundIntermediate plumbing. +# Lambda body references an outer attribute — exercises boundIntermediate plumbing. @disable_ansi_mode def test_array_aggregate_lambda_refs_outer_column(): def do_it(spark): return two_col_df(spark, ArrayGen(int_gen, max_length=10), int_gen).selectExpr( - # g(x) = (x + b) — b is a closed-over outer column, not the acc. 'aggregate(a, 0L, (acc, x) -> acc + CAST(x + b as BIGINT)) as sum_with_outer') assert_gpu_and_cpu_are_equal_collect(do_it) @@ -185,7 +181,7 @@ def test_array_aggregate_deeper_g_body(): 'aggregate(a, 0L, (acc, x) -> acc + CAST(x * 2 + 1 as BIGINT)) as sum_poly')) -# Long-overflow wrap-around: in non-ANSI mode both Spark SUM and cudf SUM wrap silently. +# Long-overflow wrap-around matches between Spark SUM and cudf SUM in non-ANSI mode. @disable_ansi_mode def test_array_aggregate_long_overflow_wraps(): def do_it(spark): @@ -195,10 +191,8 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Decimal SUM: Spark's ArrayAggregate requires merge.dataType == zero.dataType exactly. For -# DECIMAL(10,2) + DECIMAL(10,2) the result is DECIMAL(11,2), which fails analysis against a -# DECIMAL(10,2) zero. The working pattern is to widen zero to DECIMAL(38,2) (Spark's cap) -# and Cast the element to match, so `acc + Cast(x, DECIMAL(38,2))` stays at DECIMAL(38,2). +# Decimal SUM: zero must be widened to DECIMAL(38,2) (Spark's cap) with the element Cast to +# match so that merge.dataType == zero.dataType (Spark's checkInputDataTypes). @disable_ansi_mode def test_array_aggregate_decimal_sum(): decimal_gen = DecimalGen(precision=10, scale=2) @@ -207,3 +201,32 @@ def do_it(spark): 'aggregate(a, CAST(0 as DECIMAL(38,2)), ' '(acc, x) -> acc + CAST(x as DECIMAL(38,2))) as dec_sum') assert_gpu_and_cpu_are_equal_collect(do_it) + + +# Shapes the decomposer rejects must fall back to CPU. Covered: non-associative op +# (Subtract, Divide), variadic op with wrong arity (Greatest with 3 children), and a lambda +# whose g sub-expression references the accumulator. +@pytest.mark.parametrize('lambda_sql, init_sql', [ + ('(acc, x) -> acc - CAST(x as BIGINT)', '0L'), + ('(acc, x) -> CAST(acc / CAST(x + 1 as BIGINT) as BIGINT)', '1L'), + ('(acc, x) -> greatest(acc, CAST(x as BIGINT), CAST(x * 2 as BIGINT))', '-999L'), + ('(acc, x) -> acc + acc * CAST(x as BIGINT)', '0L'), +], ids=['subtract', 'divide', 'greatest-3ary', 'g-refs-acc']) +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_fallback_shapes(lambda_sql, init_sql): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( + f'aggregate(a, {init_sql}, {lambda_sql}) as res'), + 'ArrayAggregate') + + +# Non-identity finish is kept as its own test because its SQL shape (4-arg aggregate with +# a separate finish lambda) differs from the merge-only fallbacks above. +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_non_identity_finish_falls_back(): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + CAST(x as BIGINT), acc -> acc * 2) as doubled'), + 'ArrayAggregate') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 0f59ccf9b0d..ec5468cdb8b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -26,7 +26,7 @@ import com.nvidia.spark.rapids.jni.GpuMapZipWithUtils import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.catalyst.expressions.{Add, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, Cast, Expression, ExprId, LambdaFunction, NamedExpression, NamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{Add, And, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, Cast, Expression, ExprId, Greatest, LambdaFunction, Least, Multiply, NamedExpression, NamedLambdaVariable, Or} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -897,28 +897,211 @@ case class GpuMapFilter(argument: Expression, } +// ===================================================================================== +// AggOp: one case object per supported segmented reduction. Adding a new op is three +// things: define the case object, wire matchBinary to its Catalyst shape, and append it +// to ArrayAggregateDecomposer.allOps. The op owns: its cuDF aggregation + null policy, the +// identity scalar used to back-fill rows where no element contributed, and the combine- +// with-zero step. +// ===================================================================================== + +sealed trait AggOp { + def name: String + def cudfAgg: cudf.SegmentedReductionAggregation + def nullPolicy: cudf.NullPolicy + /** + * Identity element at the given Spark type. Built with a cuDF DType matching the + * reduced column so downstream ifElse / binaryOp don't hit a width mismatch. + */ + def identityScalar(sparkType: DataType, cudfDType: DType): cudf.Scalar + /** `result = reduced OP zero`, typed to outDType, with Spark-matching null propagation. */ + def combineWithZero( + reduced: cudf.ColumnVector, + zero: cudf.ColumnView, + outDType: DType): cudf.ColumnVector + /** Return (left, right) if the body is this op's Catalyst shape. */ + def matchBinary(body: Expression): Option[(Expression, Expression)] + /** Is this Spark data type supported for this op's accumulator / result? */ + def supportsType(sparkType: DataType): Boolean +} + +case object SumOp extends AggOp { + val name = "SUM" + def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.sum() + // INCLUDE: Spark iteratively computes `acc + x` and null poisons the accumulator, so + // one null element anywhere in the list yields null. + val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE + def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + case _: ByteType => cudf.Scalar.fromByte(0.toByte) + case _: ShortType => cudf.Scalar.fromShort(0.toShort) + case _: IntegerType => cudf.Scalar.fromInt(0) + case _: LongType => cudf.Scalar.fromLong(0L) + case _: FloatType => cudf.Scalar.fromFloat(0.0f) + case _: DoubleType => cudf.Scalar.fromDouble(0.0) + case _: DecimalType => + // fromDecimal(BigDecimal) picks DECIMAL32/64/128 from the value's precision, which + // may not match the reduced column's fixed width. Bind the DType explicitly. + cudf.Scalar.fromDecimal(java.math.BigInteger.ZERO, cudfT) + case other => throw new IllegalStateException(s"SUM identity not defined for $other") + } + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.add(z, out) + def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { + case a: Add => Some((a.left, a.right)) + case _ => None + } + def supportsType(t: DataType): Boolean = t match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType | _: DecimalType => true + case _ => false + } +} + +case object ProductOp extends AggOp { + val name = "PRODUCT" + def cudfAgg: cudf.SegmentedReductionAggregation = + cudf.SegmentedReductionAggregation.product() + val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE + def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + case _: ByteType => cudf.Scalar.fromByte(1.toByte) + case _: ShortType => cudf.Scalar.fromShort(1.toShort) + case _: IntegerType => cudf.Scalar.fromInt(1) + case _: LongType => cudf.Scalar.fromLong(1L) + case _: FloatType => cudf.Scalar.fromFloat(1.0f) + case _: DoubleType => cudf.Scalar.fromDouble(1.0) + case other => + throw new IllegalStateException(s"PRODUCT identity not defined for $other") + } + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.mul(z, out) + def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { + case m: Multiply => Some((m.left, m.right)) + case _ => None + } + def supportsType(t: DataType): Boolean = t match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType => true + case _ => false + } +} + /** - * Decomposes a Spark ArrayAggregate's merge lambda of shape `(acc, x) -> op(acc, g(x))` into - * a map-reduce form executable with cuDF segmented-reduction APIs. Currently supports SUM. + * MaxOp / MinOp share EXCLUDE null policy: Spark's Greatest / Least skip null operands, + * so an all-null list reduces to null (no non-null contributor) and should then fold + * back to zero via the identity substitution. */ -object ArrayAggregateDecomposer { +sealed trait ExtremumOp extends AggOp { + val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.EXCLUDE + def supportsType(t: DataType): Boolean = t match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | + _: FloatType | _: DoubleType => true + case _ => false + } +} - sealed trait AggOp - case object SumOp extends AggOp +case object MaxOp extends ExtremumOp { + val name = "MAX" + def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.max() + def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + case _: ByteType => cudf.Scalar.fromByte(Byte.MinValue) + case _: ShortType => cudf.Scalar.fromShort(Short.MinValue) + case _: IntegerType => cudf.Scalar.fromInt(Int.MinValue) + case _: LongType => cudf.Scalar.fromLong(Long.MinValue) + case _: FloatType => cudf.Scalar.fromFloat(Float.NegativeInfinity) + case _: DoubleType => cudf.Scalar.fromDouble(Double.NegativeInfinity) + case other => throw new IllegalStateException(s"MAX identity not defined for $other") + } + // Element-wise max with Spark's null propagation: if either side is null, result is null. + // cuDF has no direct MAX BinaryOp (only NULL_MAX which treats null as smallest), so use + // a compare + ifElse; null in the compare's output propagates to ifElse. + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) + : cudf.ColumnVector = { + withResource(r.greaterThan(z)) { rGreater => + rGreater.ifElse(r, z) + } + } + def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { + case g: Greatest if g.children.size == 2 => Some((g.children.head, g.children(1))) + case _ => None + } +} - /** - * @param op the segmented reduction aggregation operator - * @param gChildIndex 0 if `g` is the left child of the merge body's binary op, 1 if right - * @param accVarExprId the accumulator NamedLambdaVariable's exprId - * @param elemVar the element NamedLambdaVariable (used to build the g lambda) - */ - case class Decomposition( - op: AggOp, - gChildIndex: Int, - accVarExprId: ExprId, - elemVar: NamedLambdaVariable) +case object MinOp extends ExtremumOp { + val name = "MIN" + def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.min() + def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + case _: ByteType => cudf.Scalar.fromByte(Byte.MaxValue) + case _: ShortType => cudf.Scalar.fromShort(Short.MaxValue) + case _: IntegerType => cudf.Scalar.fromInt(Int.MaxValue) + case _: LongType => cudf.Scalar.fromLong(Long.MaxValue) + case _: FloatType => cudf.Scalar.fromFloat(Float.PositiveInfinity) + case _: DoubleType => cudf.Scalar.fromDouble(Double.PositiveInfinity) + case other => throw new IllegalStateException(s"MIN identity not defined for $other") + } + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) + : cudf.ColumnVector = { + withResource(r.lessThan(z)) { rLess => + rLess.ifElse(r, z) + } + } + def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { + case l: Least if l.children.size == 2 => Some((l.children.head, l.children(1))) + case _ => None + } +} + +case object AllOp extends AggOp { + val name = "ALL" + def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.all() + // INCLUDE: matches Spark's 3VL for AND (null AND true = null, null AND false = false). + val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE + def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = cudf.Scalar.fromBool(true) + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.and(z) + def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { + case a: And => Some((a.left, a.right)) + case _ => None + } + def supportsType(t: DataType): Boolean = t.isInstanceOf[BooleanType] +} - def decompose(merge: Expression, finish: Expression): Option[Decomposition] = { +case object AnyOp extends AggOp { + val name = "ANY" + def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.any() + val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE + def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = cudf.Scalar.fromBool(false) + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.or(z) + def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { + case o: Or => Some((o.left, o.right)) + case _ => None + } + def supportsType(t: DataType): Boolean = t.isInstanceOf[BooleanType] +} + + +/** + * Result of successfully matching a Spark ArrayAggregate's merge lambda against a + * registered AggOp. + * + * @param op the chosen aggregation operator + * @param gChildIndex 0 if `g` is the left child of the merge body's binary op, 1 if right + * @param accVarExprId the accumulator NamedLambdaVariable's exprId + * @param elemVar the element NamedLambdaVariable (used to build the g lambda) + */ +case class ArrayAggregateDecomposition( + op: AggOp, + gChildIndex: Int, + accVarExprId: ExprId, + elemVar: NamedLambdaVariable) + + +/** + * Decomposes a Spark ArrayAggregate's merge lambda of shape `(acc, x) -> op(acc, g(x))` + * where `op` is one of the registered AggOps and the finish lambda is identity. + */ +object ArrayAggregateDecomposer { + + /** All ops the decomposer will try, in order. */ + val allOps: Seq[AggOp] = Seq(SumOp, ProductOp, MaxOp, MinOp, AllOp, AnyOp) + + def decompose(merge: Expression, finish: Expression): Option[ArrayAggregateDecomposition] = { val mergeLambda = merge match { case lf: LambdaFunction => lf case _ => return None @@ -929,18 +1112,17 @@ object ArrayAggregateDecomposer { } if (!isFinishIdentity(finish)) return None - mergeLambda.function match { - case add: Add => - val accId = accVar.exprId - if (isAccRef(add.left, accId) && !containsAccRef(add.right, accId)) { - Some(Decomposition(SumOp, 1, accId, elemVar)) - } else if (isAccRef(add.right, accId) && !containsAccRef(add.left, accId)) { - Some(Decomposition(SumOp, 0, accId, elemVar)) - } else { - None - } - case _ => None - } + val body = mergeLambda.function + val accId = accVar.exprId + allOps.view.flatMap { op => + op.matchBinary(body).flatMap { case (l, r) => + if (isAccRef(l, accId) && !containsAccRef(r, accId)) { + Some(ArrayAggregateDecomposition(op, 1, accId, elemVar)) + } else if (isAccRef(r, accId) && !containsAccRef(l, accId)) { + Some(ArrayAggregateDecomposition(op, 0, accId, elemVar)) + } else None + } + }.headOption } private def isFinishIdentity(finish: Expression): Boolean = finish match { @@ -963,20 +1145,20 @@ object ArrayAggregateDecomposer { /** - * GPU implementation of ArrayAggregate restricted to lambdas decomposable via - * ArrayAggregateDecomposer. Runtime steps: + * GPU implementation of ArrayAggregate for lambdas decomposable via ArrayAggregateDecomposer. + * Runtime steps: * 1. Evaluate g(x) over the array children (reusing GpuArrayTransformBase's explode path). * 2. Rewrap as list with the original offsets and validity. - * 3. cuDF segmented reduce. - * 4. Replace null reduction results (from empty lists) with op's identity. - * 5. Combine with zero: result = zero op filled. - * 6. Restore null for rows where the input array was null. + * 3. cuDF segmented reduce with the op's null policy. + * 4. Substitute op's identity into rows where reduce returned null due to "no elements + * contributed" (the exact condition depends on null policy; see `substituteMask`). + * 5. Combine with zero: `result = reduced OP zero`. */ case class GpuArrayAggregate( argument: Expression, zero: Expression, function: Expression, - op: ArrayAggregateDecomposer.AggOp, + op: AggOp, isBound: Boolean = false, boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayTransformBase { @@ -986,7 +1168,6 @@ case class GpuArrayAggregate( override def prettyName: String = "array_aggregate" - // Include zero as a child so analyzer / optimizer passes can see it. override def children: Seq[Expression] = argument :: zero :: function :: Nil override def bind(input: AttributeSeq): GpuExpression = { @@ -995,8 +1176,6 @@ case class GpuArrayAggregate( GpuArrayAggregate(boundArg, boundZero, boundFunc, op, isBound = true, boundInter) } - // We override columnarEval entirely; the base class's template isn't a fit because - // the lambda output still needs a segmented reduction plus combine-with-zero. override protected def transformListColumnView( lambdaTransformedCV: cudf.ColumnView, arg: cudf.ColumnView): GpuColumnVector = { @@ -1004,55 +1183,39 @@ case class GpuArrayAggregate( "GpuArrayAggregate overrides columnarEval; transformListColumnView is unused") } - private def cudfAgg: cudf.SegmentedReductionAggregation = op match { - case ArrayAggregateDecomposer.SumOp => cudf.SegmentedReductionAggregation.sum() - } - - private def identityScalar(outDType: DType): cudf.Scalar = op match { - case ArrayAggregateDecomposer.SumOp => - dataType match { - case _: ByteType => cudf.Scalar.fromByte(0.toByte) - case _: ShortType => cudf.Scalar.fromShort(0.toShort) - case _: IntegerType => cudf.Scalar.fromInt(0) - case _: LongType => cudf.Scalar.fromLong(0L) - case _: FloatType => cudf.Scalar.fromFloat(0.0f) - case _: DoubleType => cudf.Scalar.fromDouble(0.0) - case _: DecimalType => - // BigDecimal-based fromDecimal picks DECIMAL32/64/128 from the value's precision, - // which does not match the reduced column's fixed width. Bind to the column's - // DType explicitly so ifElse and add don't see a width mismatch. - cudf.Scalar.fromDecimal(java.math.BigInteger.ZERO, outDType) - case other => - throw new IllegalStateException(s"SUM identity not defined for $other") - } - } - - private def combineWithZero( - filled: cudf.ColumnVector, - zeroCv: cudf.ColumnView, - outDType: DType): cudf.ColumnVector = op match { - case ArrayAggregateDecomposer.SumOp => filled.add(zeroCv, outDType) - } - /** - * Boolean mask: true iff the list is empty *and not null*. Used to substitute op's identity - * only for the empty-list rows, while letting null lists and null elements propagate - * through the subsequent combine-with-zero step. + * Mask of rows where the reduce result must be replaced with the op's identity. + * + * INCLUDE ops (SUM/PRODUCT/ALL/ANY): only empty-and-not-null lists. Null-poisoned + * reduces stay null and propagate through the combine step, matching Spark's iterative + * `acc op null = null` semantics. + * + * EXCLUDE ops (MAX/MIN): any reduce-null over a non-null list — covers both empty lists + * and all-null lists, matching Spark's Greatest/Least which skip nulls entirely. */ - private def emptyNotNullMask(listCol: cudf.ColumnView): cudf.ColumnVector = { - withResource(listCol.countElements()) { counts => - withResource(cudf.Scalar.fromInt(0)) { zeroInt => - withResource(counts.equalTo(zeroInt)) { isEmpty => - if (argument.nullable) { - withResource(listCol.isNotNull) { isNotNull => - isEmpty.and(isNotNull) + private def substituteMask( + listCol: cudf.ColumnView, + reduced: cudf.ColumnVector): cudf.ColumnVector = op.nullPolicy match { + case cudf.NullPolicy.INCLUDE => + withResource(listCol.countElements()) { counts => + withResource(cudf.Scalar.fromInt(0)) { zeroInt => + withResource(counts.equalTo(zeroInt)) { isEmpty => + if (argument.nullable) { + withResource(listCol.isNotNull) { isNotNull => isEmpty.and(isNotNull) } + } else { + isEmpty.incRefCount() } - } else { - isEmpty.incRefCount() } } } - } + case cudf.NullPolicy.EXCLUDE => + withResource(reduced.isNull) { reducedIsNull => + if (argument.nullable) { + withResource(listCol.isNotNull) { isNotNull => reducedIsNull.and(isNotNull) } + } else { + reducedIsNull.incRefCount() + } + } } override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { @@ -1065,23 +1228,30 @@ case class GpuArrayAggregate( arg.getBase, transformedData.getBase) withResource(listOfGView) { listOfGView => val outDType = GpuColumnVector.getNonNestedRapidsType(dataType) - // INCLUDE nulls: Spark evaluates `acc + g(x)` and null poisons the accumulator, so - // any null element in the reduced-over list produces null. cuDF also returns null - // for null lists and empty lists. We substitute identity only for the empty-but- - // not-null case; null lists and null-poisoned sums stay null and fall through the - // add(zero) step preserving null. - withResource(listOfGView.listReduce(cudfAgg, cudf.NullPolicy.INCLUDE, outDType)) { - reduced => - withResource(emptyNotNullMask(arg.getBase)) { isEmptyNotNull => - withResource(identityScalar(outDType)) { idScalar => - withResource(isEmptyNotNull.ifElse(idScalar, reduced)) { adjusted => - withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => - GpuColumnVector.from( - combineWithZero(adjusted, zeroCv.getBase, outDType), dataType) + withResource(listOfGView.listReduce(op.cudfAgg, op.nullPolicy, outDType)) { reduced => + withResource(substituteMask(arg.getBase, reduced)) { mask => + withResource(op.identityScalar(dataType, outDType)) { idScalar => + withResource(mask.ifElse(idScalar, reduced)) { adjusted => + withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => + withResource( + op.combineWithZero(adjusted, zeroCv.getBase, outDType)) { combined => + // Unconditionally restore null for rows where the input list itself + // was null. Not all cuDF binary ops (e.g. GREATER / LOGICAL_AND) + // propagate null the way Spark's 3VL would, so we can't rely on the + // combine step to preserve null-list semantics. Doing the restore + // even when the argument is declared non-nullable is a no-op (isNull + // is all-false) and avoids fragile reliance on the nullable flag. + withResource(arg.getBase.isNull) { isNullList => + withResource(cudf.Scalar.fromNull(outDType)) { nullScalar => + GpuColumnVector.from( + isNullList.ifElse(nullScalar, combined), dataType) + } + } } } } } + } } } } @@ -1091,8 +1261,9 @@ case class GpuArrayAggregate( /** - * Expression-level meta for Spark's ArrayAggregate. Only accepts lambdas that - * ArrayAggregateDecomposer can decompose into (op, g); otherwise falls back to CPU. + * Expression-level meta for Spark's ArrayAggregate. Accepts lambdas that + * ArrayAggregateDecomposer can decompose into one of the registered AggOps with an + * identity finish; falls back to CPU otherwise. */ class GpuArrayAggregateMeta( expr: ArrayAggregate, @@ -1101,26 +1272,50 @@ class GpuArrayAggregateMeta( rule: DataFromReplacementRule) extends ExprMeta[ArrayAggregate](expr, conf, parent, rule) { - private var decomposition: Option[ArrayAggregateDecomposer.Decomposition] = None + private var decomposition: Option[ArrayAggregateDecomposition] = None override def tagExprForGpu(): Unit = { val d = ArrayAggregateDecomposer.decompose(expr.merge, expr.finish) if (d.isEmpty) { willNotWorkOnGpu( - "ArrayAggregate only supports lambdas of the form (acc, x) -> acc + g(x) with " + - "an identity finish lambda (SUM only for now). Other shapes are not supported.") + "ArrayAggregate only supports lambdas of the form (acc, x) -> op(acc, g(x)) " + + "with an identity finish lambda, where op is one of " + + ArrayAggregateDecomposer.allOps.map(_.name).mkString(", ") + ".") + return + } + val decomp = d.get + if (!decomp.op.supportsType(expr.zero.dataType)) { + willNotWorkOnGpu( + s"${decomp.op.name} is not supported on GPU for type ${expr.zero.dataType}") return } // g's output type must equal the accumulator/zero type so the segmented reduce output // matches the Spark-expected result type directly. - val body = expr.merge.asInstanceOf[LambdaFunction].function.asInstanceOf[Add] - val gType = body.children(d.get.gChildIndex).dataType + val body = expr.merge.asInstanceOf[LambdaFunction].function + val gType = decomp.op.matchBinary(body).get match { + case (_, r) if decomp.gChildIndex == 1 => r.dataType + case (l, _) => l.dataType + } if (!DataType.equalsStructurally(gType, expr.zero.dataType, ignoreNullability = true)) { willNotWorkOnGpu( s"g(x) output type ($gType) does not match accumulator/zero type " + s"(${expr.zero.dataType})") return } + // cuDF's segmented ALL/ANY with INCLUDE nulls doesn't match Spark's AND/OR 3VL + // (specifically: `false AND null = false` short-circuit, or `true OR null = true`, are + // both missed by cuDF which returns null whenever any null is present). Fall back to + // CPU when the input array can contain nulls. + if (decomp.op == AllOp || decomp.op == AnyOp) { + expr.argument.dataType match { + case ArrayType(_, containsNull) if containsNull => + willNotWorkOnGpu( + s"${decomp.op.name} is not supported on GPU for arrays that may contain nulls; " + + "cuDF's INCLUDE-nulls semantics don't match Spark's AND/OR 3VL") + return + case _ => + } + } decomposition = d } @@ -1130,8 +1325,10 @@ class GpuArrayAggregateMeta( val argGpu = childExprs.head.convertToGpu() val zeroGpu = childExprs(1).convertToGpu() - // childExprs(2) is the merge lambda meta; its first child is the Add body meta, whose - // gChildIndex-th child is the g sub-expression we want on GPU. + // childExprs(2) is the merge lambda meta; its first child is the op body meta, whose + // gChildIndex-th child is the g sub-expression. For binary catalyst shapes (Add, + // Multiply, And, Or) children are [left, right]; for variadic shapes restricted to + // size==2 (Greatest, Least) children are also [left, right]. So the index lines up. val bodyMeta = childExprs(2).childExprs.head val gGpu = bodyMeta.childExprs(d.gChildIndex).convertToGpu() val elemVarGpu = GpuNamedLambdaVariable( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala index 85089167432..1f9c249df47 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala @@ -16,15 +16,16 @@ package com.nvidia.spark.rapids -import org.apache.spark.sql.catalyst.expressions.{Add, Cast, Expression, LambdaFunction, - Literal, Multiply, NamedExpression, NamedLambdaVariable, Subtract} -import org.apache.spark.sql.types.{DataType, IntegerType, LongType} - -// Extends GpuUnitTests so SQLConf.get is available for the default evalMode/failOnError -// parameter on Add/Subtract/Multiply (the field name differs across Spark versions; letting -// Spark apply its own default keeps this test shim-agnostic). +import org.apache.spark.sql.catalyst.expressions.{Add, And, Cast, Divide, Expression, + Greatest, LambdaFunction, Least, Literal, Multiply, NamedExpression, NamedLambdaVariable, + Or, Subtract} +import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, LongType} + +// Extends GpuUnitTests so SQLConf.get is available for the default evalMode / failOnError +// parameter on Add/Subtract/Multiply/Divide (the field name differs across Spark versions; +// letting Spark apply its own default keeps this test shim-agnostic). class ArrayAggregateDecomposerSuite extends GpuUnitTests { - import ArrayAggregateDecomposer._ + import ArrayAggregateDecomposer.decompose // --- helpers ----------------------------------------------------------- @@ -43,141 +44,165 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { private def plus(l: Expression, r: Expression): Add = Add(l, r) private def minus(l: Expression, r: Expression): Subtract = Subtract(l, r) private def times(l: Expression, r: Expression): Multiply = Multiply(l, r) + private def div(l: Expression, r: Expression): Divide = Divide(l, r) + private def greatest(l: Expression, r: Expression): Greatest = Greatest(Seq(l, r)) + private def least(l: Expression, r: Expression): Least = Least(Seq(l, r)) - // --- positive cases ---------------------------------------------------- - - test("Add(acc, x) decomposes to SUM with gChildIndex=1") { - val acc = lv("acc") - val x = lv("x") - val m = merge(plus(acc, x), acc, x) - val d = decompose(m, identityFinish(acc)) - assert(d.isDefined) - assert(d.get.op == SumOp) - assert(d.get.gChildIndex == 1) + /** Assert decomposition succeeds; returns the ArrayAggregateDecomposition for further checks. */ + private def assertDecomposes( + body: Expression, + acc: NamedLambdaVariable, + x: NamedLambdaVariable, + expectedOp: AggOp, + expectedGChildIndex: Int): ArrayAggregateDecomposition = { + val d = decompose(merge(body, acc, x), identityFinish(acc)) + assert(d.isDefined, s"expected decomposition for body=$body") + assert(d.get.op == expectedOp) + assert(d.get.gChildIndex == expectedGChildIndex) assert(d.get.accVarExprId == acc.exprId) assert(d.get.elemVar.exprId == x.exprId) + d.get } - test("Add(x, acc) (commuted) decomposes with gChildIndex=0") { - val acc = lv("acc") - val x = lv("x") - val m = merge(plus(x, acc), acc, x) - val d = decompose(m, identityFinish(acc)) - assert(d.isDefined) - assert(d.get.gChildIndex == 0) - } - - test("Add(acc, complex g(x)) where g contains no acc ref decomposes") { - val acc = lv("acc", LongType) - val x = lv("x", IntegerType) - // g = x * 2 + 1 (as Long after Cast) - val g = plus(times(x, Literal(2)), Literal(1)) - val m = merge(plus(acc, Cast(g, LongType)), acc, x) - val d = decompose(m, identityFinish(acc)) - assert(d.isDefined) - assert(d.get.gChildIndex == 1) - } - - test("acc wrapped in a Cast on the acc-side is still accepted") { - val acc = lv("acc", LongType) - val x = lv("x", IntegerType) - // body = Cast(acc, Int) + x - val m = merge(plus(Cast(acc, IntegerType), x), acc, x) - val d = decompose(m, identityFinish(acc)) - assert(d.isDefined) - assert(d.get.gChildIndex == 1) - } - - test("chained Cast on acc side is unwrapped") { - val acc = lv("acc") - val x = lv("x") - // body = Cast(Cast(acc, Long), Int) + x - val accDoubleCast = Cast(Cast(acc, LongType), IntegerType) - val m = merge(plus(accDoubleCast, x), acc, x) - val d = decompose(m, identityFinish(acc)) - assert(d.isDefined) + private def assertRejects( + mergeBody: LambdaFunction, + finish: Expression, + reason: String): Unit = { + assert(decompose(mergeBody, finish).isEmpty, reason) } - // --- negative cases ---------------------------------------------------- + // --- positive: one per op --------------------------------------------- - test("Subtract rejected (only Add is SUM)") { - val acc = lv("acc") - val x = lv("x") - val m = merge(minus(acc, x), acc, x) - assert(decompose(m, identityFinish(acc)).isEmpty) + test("Add(acc, x) -> SUM, gChildIndex=1") { + val acc = lv("acc"); val x = lv("x") + assertDecomposes(plus(acc, x), acc, x, SumOp, 1) } - test("Multiply rejected") { - val acc = lv("acc") - val x = lv("x") - val m = merge(times(acc, x), acc, x) - assert(decompose(m, identityFinish(acc)).isEmpty) + test("Add(x, acc) (commuted) -> SUM, gChildIndex=0") { + val acc = lv("acc"); val x = lv("x") + assertDecomposes(plus(x, acc), acc, x, SumOp, 0) + } + + test("Multiply(acc, x) -> PRODUCT") { + val acc = lv("acc"); val x = lv("x") + assertDecomposes(times(acc, x), acc, x, ProductOp, 1) + } + + test("Greatest(acc, x) -> MAX") { + val acc = lv("acc"); val x = lv("x") + assertDecomposes(greatest(acc, x), acc, x, MaxOp, 1) + } + + test("Least(acc, x) -> MIN") { + val acc = lv("acc"); val x = lv("x") + assertDecomposes(least(acc, x), acc, x, MinOp, 1) + } + + test("And(acc, x) -> ALL") { + val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) + assertDecomposes(And(acc, x), acc, x, AllOp, 1) + } + + test("Or(acc, x) -> ANY") { + val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) + assertDecomposes(Or(acc, x), acc, x, AnyOp, 1) + } + + // --- positive: structural variations ---------------------------------- + + test("Complex g(x) with no acc ref still decomposes") { + val acc = lv("acc", LongType); val x = lv("x", IntegerType) + // g = Cast(x * 2 + 1, Long) + val g = Cast(plus(times(x, Literal(2)), Literal(1)), LongType) + assertDecomposes(plus(acc, g), acc, x, SumOp, 1) + } + + test("Cast wrapping the acc side is unwrapped (single layer)") { + val acc = lv("acc", LongType); val x = lv("x", IntegerType) + assertDecomposes(plus(Cast(acc, IntegerType), x), acc, x, SumOp, 1) + } + + test("Cast wrapping the acc side is unwrapped (chained)") { + val acc = lv("acc"); val x = lv("x") + val doubleCastAcc = Cast(Cast(acc, LongType), IntegerType) + assertDecomposes(plus(doubleCastAcc, x), acc, x, SumOp, 1) + } + + // --- negative: wrong shape -------------------------------------------- + + test("Subtract is not an associative op we recognize") { + val acc = lv("acc"); val x = lv("x") + assertRejects(merge(minus(acc, x), acc, x), identityFinish(acc), + "Subtract is not in the registered AggOps") + } + + test("Divide is not an associative op we recognize") { + val acc = lv("acc"); val x = lv("x") + assertRejects(merge(div(acc, x), acc, x), identityFinish(acc), + "Divide is not in the registered AggOps") + } + + test("Greatest with arity != 2 is not decomposed") { + val acc = lv("acc"); val x = lv("x") + val body = Greatest(Seq(acc, x, Literal(1))) + assertRejects(merge(body, acc, x), identityFinish(acc), + "Greatest with 3 children is not a 2-operand op") } test("g that references acc is rejected") { - val acc = lv("acc") - val x = lv("x") - // g = acc * x — references acc - val m = merge(plus(acc, times(acc, x)), acc, x) - assert(decompose(m, identityFinish(acc)).isEmpty) + val acc = lv("acc"); val x = lv("x") + // g = acc * x, references acc + assertRejects(merge(plus(acc, times(acc, x)), acc, x), identityFinish(acc), + "g must not reference acc") } test("both sides reference acc is rejected") { - val acc = lv("acc") - val x = lv("x") - val m = merge(plus(acc, acc), acc, x) - assert(decompose(m, identityFinish(acc)).isEmpty) + val acc = lv("acc"); val x = lv("x") + assertRejects(merge(plus(acc, acc), acc, x), identityFinish(acc), + "neither side is a 'pure non-acc'") } test("neither side is a pure acc ref is rejected") { - val acc = lv("acc") - val x = lv("x") - // body = (acc + 1) + x — left has acc but isn't a naked acc ref - val leftWithPlusOne = plus(acc, Literal(1)) - val m = merge(plus(leftWithPlusOne, x), acc, x) - assert(decompose(m, identityFinish(acc)).isEmpty) + val acc = lv("acc"); val x = lv("x") + // body = (acc + 1) + x + assertRejects(merge(plus(plus(acc, Literal(1)), x), acc, x), identityFinish(acc), + "left side isn't a naked acc ref") } - test("non-identity finish rejected") { - val acc = lv("acc") - val x = lv("x") - val m = merge(plus(acc, x), acc, x) + // --- negative: finish lambda ------------------------------------------ + + test("non-identity finish is rejected") { + val acc = lv("acc"); val x = lv("x") val finishAcc = lv("finishAcc") - val nonIdentityFinish = - LambdaFunction(plus(finishAcc, Literal(1)), Seq(finishAcc)) - assert(decompose(m, nonIdentityFinish).isEmpty) + val badFinish = LambdaFunction(plus(finishAcc, Literal(1)), Seq(finishAcc)) + assertRejects(merge(plus(acc, x), acc, x), badFinish, + "finish that multiplies the accumulator isn't identity") } test("finish referencing a different variable id is rejected") { - val acc = lv("acc") - val x = lv("x") - val m = merge(plus(acc, x), acc, x) - // finish's body references a NamedLambdaVariable with a *different* exprId, - // so it is not the identity over the finish's arg. + val acc = lv("acc"); val x = lv("x") val finishAcc = lv("finishAcc") - val someOther = lv("other") // different exprId - val badFinish = LambdaFunction(someOther, Seq(finishAcc)) - assert(decompose(m, badFinish).isEmpty) + val otherVar = lv("other") + val badFinish = LambdaFunction(otherVar, Seq(finishAcc)) + assertRejects(merge(plus(acc, x), acc, x), badFinish, + "finish body refers to a variable that isn't its own arg") } - test("merge with wrong arg count rejected") { - val acc = lv("acc") - val x = lv("x") - val extra = lv("extra") - val m = LambdaFunction(plus(acc, x), Seq(acc, x, extra)) - assert(decompose(m, identityFinish(acc)).isEmpty) + // --- negative: shape sanity -------------------------------------------- + + test("merge with wrong arg count is rejected") { + val acc = lv("acc"); val x = lv("x"); val extra = lv("extra") + val body = LambdaFunction(plus(acc, x), Seq(acc, x, extra)) + assertRejects(body, identityFinish(acc), "merge must take 2 lambda args") } test("merge that isn't a LambdaFunction at all is rejected") { val acc = lv("acc") - // Pass something that isn't a lambda. assert(decompose(plus(Literal(1), Literal(2)), identityFinish(acc)).isEmpty) } - test("finish that isn't a LambdaFunction rejected") { - val acc = lv("acc") - val x = lv("x") - val m = merge(plus(acc, x), acc, x) - assert(decompose(m, Literal(0)).isEmpty) + test("finish that isn't a LambdaFunction is rejected") { + val acc = lv("acc"); val x = lv("x") + assert(decompose(merge(plus(acc, x), acc, x), Literal(0)).isEmpty) } } From e31d377bbfe4c77c80cd8b521946e0efaf3098d9 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 22 Apr 2026 16:05:17 +0800 Subject: [PATCH 03/12] simplify Signed-off-by: Haoyang Li --- .../spark/rapids/higherOrderFunctions.scala | 219 +++++++++--------- 1 file changed, 104 insertions(+), 115 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index ec5468cdb8b..ebdef7e0716 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -897,34 +897,34 @@ case class GpuMapFilter(argument: Expression, } -// ===================================================================================== -// AggOp: one case object per supported segmented reduction. Adding a new op is three -// things: define the case object, wire matchBinary to its Catalyst shape, and append it -// to ArrayAggregateDecomposer.allOps. The op owns: its cuDF aggregation + null policy, the -// identity scalar used to back-fill rows where no element contributed, and the combine- -// with-zero step. -// ===================================================================================== +// Registered segmented reductions used by GpuArrayAggregate. To add a new op: define the +// case object, wire matchBinary to its Catalyst shape, and append it to +// ArrayAggregateDecomposer.allOps. sealed trait AggOp { def name: String def cudfAgg: cudf.SegmentedReductionAggregation def nullPolicy: cudf.NullPolicy - /** - * Identity element at the given Spark type. Built with a cuDF DType matching the - * reduced column so downstream ifElse / binaryOp don't hit a width mismatch. - */ + /** Identity scalar, built with `cudfDType` so ifElse / binaryOp don't hit width mismatch. */ def identityScalar(sparkType: DataType, cudfDType: DType): cudf.Scalar - /** `result = reduced OP zero`, typed to outDType, with Spark-matching null propagation. */ + /** `reduced OP zero`, typed to outDType, with Spark-matching null propagation. */ def combineWithZero( reduced: cudf.ColumnVector, zero: cudf.ColumnView, outDType: DType): cudf.ColumnVector /** Return (left, right) if the body is this op's Catalyst shape. */ def matchBinary(body: Expression): Option[(Expression, Expression)] - /** Is this Spark data type supported for this op's accumulator / result? */ def supportsType(sparkType: DataType): Boolean } +object AggOp { + /** Numeric types that all basic cuDF reductions (sum/product/max/min) accept. */ + def isPlainNumeric(t: DataType): Boolean = t match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } +} + case object SumOp extends AggOp { val name = "SUM" def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.sum() @@ -932,16 +932,13 @@ case object SumOp extends AggOp { // one null element anywhere in the list yields null. val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { - case _: ByteType => cudf.Scalar.fromByte(0.toByte) - case _: ShortType => cudf.Scalar.fromShort(0.toShort) - case _: IntegerType => cudf.Scalar.fromInt(0) - case _: LongType => cudf.Scalar.fromLong(0L) - case _: FloatType => cudf.Scalar.fromFloat(0.0f) - case _: DoubleType => cudf.Scalar.fromDouble(0.0) - case _: DecimalType => - // fromDecimal(BigDecimal) picks DECIMAL32/64/128 from the value's precision, which - // may not match the reduced column's fixed width. Bind the DType explicitly. - cudf.Scalar.fromDecimal(java.math.BigInteger.ZERO, cudfT) + case ByteType => cudf.Scalar.fromByte(0.toByte) + case ShortType => cudf.Scalar.fromShort(0.toShort) + case IntegerType => cudf.Scalar.fromInt(0) + case LongType => cudf.Scalar.fromLong(0L) + case FloatType => cudf.Scalar.fromFloat(0.0f) + case DoubleType => cudf.Scalar.fromDouble(0.0) + case d: DecimalType => GpuScalar.from(0, d) case other => throw new IllegalStateException(s"SUM identity not defined for $other") } def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.add(z, out) @@ -949,11 +946,8 @@ case object SumOp extends AggOp { case a: Add => Some((a.left, a.right)) case _ => None } - def supportsType(t: DataType): Boolean = t match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | - _: FloatType | _: DoubleType | _: DecimalType => true - case _ => false - } + def supportsType(t: DataType): Boolean = + AggOp.isPlainNumeric(t) || t.isInstanceOf[DecimalType] } case object ProductOp extends AggOp { @@ -962,56 +956,45 @@ case object ProductOp extends AggOp { cudf.SegmentedReductionAggregation.product() val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { - case _: ByteType => cudf.Scalar.fromByte(1.toByte) - case _: ShortType => cudf.Scalar.fromShort(1.toShort) - case _: IntegerType => cudf.Scalar.fromInt(1) - case _: LongType => cudf.Scalar.fromLong(1L) - case _: FloatType => cudf.Scalar.fromFloat(1.0f) - case _: DoubleType => cudf.Scalar.fromDouble(1.0) - case other => - throw new IllegalStateException(s"PRODUCT identity not defined for $other") + case ByteType => cudf.Scalar.fromByte(1.toByte) + case ShortType => cudf.Scalar.fromShort(1.toShort) + case IntegerType => cudf.Scalar.fromInt(1) + case LongType => cudf.Scalar.fromLong(1L) + case FloatType => cudf.Scalar.fromFloat(1.0f) + case DoubleType => cudf.Scalar.fromDouble(1.0) + case other => throw new IllegalStateException(s"PRODUCT identity not defined for $other") } def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.mul(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case m: Multiply => Some((m.left, m.right)) case _ => None } - def supportsType(t: DataType): Boolean = t match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | - _: FloatType | _: DoubleType => true - case _ => false - } + def supportsType(t: DataType): Boolean = AggOp.isPlainNumeric(t) } /** * MaxOp / MinOp share EXCLUDE null policy: Spark's Greatest / Least skip null operands, - * so an all-null list reduces to null (no non-null contributor) and should then fold - * back to zero via the identity substitution. + * so an all-null list reduces to null and then folds back to zero via identity substitution. */ sealed trait ExtremumOp extends AggOp { val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.EXCLUDE - def supportsType(t: DataType): Boolean = t match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | - _: FloatType | _: DoubleType => true - case _ => false - } + def supportsType(t: DataType): Boolean = AggOp.isPlainNumeric(t) } case object MaxOp extends ExtremumOp { val name = "MAX" def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.max() def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { - case _: ByteType => cudf.Scalar.fromByte(Byte.MinValue) - case _: ShortType => cudf.Scalar.fromShort(Short.MinValue) - case _: IntegerType => cudf.Scalar.fromInt(Int.MinValue) - case _: LongType => cudf.Scalar.fromLong(Long.MinValue) - case _: FloatType => cudf.Scalar.fromFloat(Float.NegativeInfinity) - case _: DoubleType => cudf.Scalar.fromDouble(Double.NegativeInfinity) + case ByteType => cudf.Scalar.fromByte(Byte.MinValue) + case ShortType => cudf.Scalar.fromShort(Short.MinValue) + case IntegerType => cudf.Scalar.fromInt(Int.MinValue) + case LongType => cudf.Scalar.fromLong(Long.MinValue) + case FloatType => cudf.Scalar.fromFloat(Float.NegativeInfinity) + case DoubleType => cudf.Scalar.fromDouble(Double.NegativeInfinity) case other => throw new IllegalStateException(s"MAX identity not defined for $other") } - // Element-wise max with Spark's null propagation: if either side is null, result is null. - // cuDF has no direct MAX BinaryOp (only NULL_MAX which treats null as smallest), so use - // a compare + ifElse; null in the compare's output propagates to ifElse. + // cuDF's NULL_MAX treats null as smallest (wrong for Spark), so emulate null-propagating + // max via compare + ifElse; null in the compare's result propagates through ifElse. def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) : cudf.ColumnVector = { withResource(r.greaterThan(z)) { rGreater => @@ -1028,12 +1011,12 @@ case object MinOp extends ExtremumOp { val name = "MIN" def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.min() def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { - case _: ByteType => cudf.Scalar.fromByte(Byte.MaxValue) - case _: ShortType => cudf.Scalar.fromShort(Short.MaxValue) - case _: IntegerType => cudf.Scalar.fromInt(Int.MaxValue) - case _: LongType => cudf.Scalar.fromLong(Long.MaxValue) - case _: FloatType => cudf.Scalar.fromFloat(Float.PositiveInfinity) - case _: DoubleType => cudf.Scalar.fromDouble(Double.PositiveInfinity) + case ByteType => cudf.Scalar.fromByte(Byte.MaxValue) + case ShortType => cudf.Scalar.fromShort(Short.MaxValue) + case IntegerType => cudf.Scalar.fromInt(Int.MaxValue) + case LongType => cudf.Scalar.fromLong(Long.MaxValue) + case FloatType => cudf.Scalar.fromFloat(Float.PositiveInfinity) + case DoubleType => cudf.Scalar.fromDouble(Double.PositiveInfinity) case other => throw new IllegalStateException(s"MIN identity not defined for $other") } def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) @@ -1059,7 +1042,7 @@ case object AllOp extends AggOp { case a: And => Some((a.left, a.right)) case _ => None } - def supportsType(t: DataType): Boolean = t.isInstanceOf[BooleanType] + def supportsType(t: DataType): Boolean = t == BooleanType } case object AnyOp extends AggOp { @@ -1072,7 +1055,7 @@ case object AnyOp extends AggOp { case o: Or => Some((o.left, o.right)) case _ => None } - def supportsType(t: DataType): Boolean = t.isInstanceOf[BooleanType] + def supportsType(t: DataType): Boolean = t == BooleanType } @@ -1195,65 +1178,71 @@ case class GpuArrayAggregate( */ private def substituteMask( listCol: cudf.ColumnView, - reduced: cudf.ColumnVector): cudf.ColumnVector = op.nullPolicy match { - case cudf.NullPolicy.INCLUDE => - withResource(listCol.countElements()) { counts => - withResource(cudf.Scalar.fromInt(0)) { zeroInt => - withResource(counts.equalTo(zeroInt)) { isEmpty => - if (argument.nullable) { - withResource(listCol.isNotNull) { isNotNull => isEmpty.and(isNotNull) } - } else { - isEmpty.incRefCount() - } + reduced: cudf.ColumnVector): cudf.ColumnVector = { + val reducedIsEmpty = op.nullPolicy match { + case cudf.NullPolicy.INCLUDE => + // Empty-and-not-null only. Null-poisoned reduces stay null to match Spark's + // iterative `acc op null = null` semantics. + withResource(listCol.countElements()) { counts => + withResource(cudf.Scalar.fromInt(0)) { zero => + counts.equalTo(zero) } } - } - case cudf.NullPolicy.EXCLUDE => - withResource(reduced.isNull) { reducedIsNull => - if (argument.nullable) { - withResource(listCol.isNotNull) { isNotNull => reducedIsNull.and(isNotNull) } - } else { - reducedIsNull.incRefCount() - } - } + case cudf.NullPolicy.EXCLUDE => + // Any reduce-null: empty list OR all-null list (both mean "no element contributed"), + // matching Spark's Greatest/Least which skip nulls. + reduced.isNull + } + // Exclude null-list rows from the mask so the final null-restoration step handles them. + // For non-nullable columns this is effectively a no-op (isNotNull is all-true). + withResource(reducedIsEmpty) { m => + withResource(listCol.isNotNull) { isNotNull => m.and(isNotNull) } + } } override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { + val outDType = GpuColumnVector.getNonNestedRapidsType(dataType) withResource(argument.asInstanceOf[GpuExpression].columnarEval(batch)) { arg => - val transformedData = withResource(makeElementProjectBatch(batch, arg)) { cb => - function.asInstanceOf[GpuExpression].columnarEval(cb) - } - withResource(transformedData) { transformedData => - val listOfGView = GpuListUtils.replaceListDataColumnAsView( - arg.getBase, transformedData.getBase) - withResource(listOfGView) { listOfGView => - val outDType = GpuColumnVector.getNonNestedRapidsType(dataType) - withResource(listOfGView.listReduce(op.cudfAgg, op.nullPolicy, outDType)) { reduced => - withResource(substituteMask(arg.getBase, reduced)) { mask => - withResource(op.identityScalar(dataType, outDType)) { idScalar => - withResource(mask.ifElse(idScalar, reduced)) { adjusted => - withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => - withResource( - op.combineWithZero(adjusted, zeroCv.getBase, outDType)) { combined => - // Unconditionally restore null for rows where the input list itself - // was null. Not all cuDF binary ops (e.g. GREATER / LOGICAL_AND) - // propagate null the way Spark's 3VL would, so we can't rely on the - // combine step to preserve null-list semantics. Doing the restore - // even when the argument is declared non-nullable is a no-op (isNull - // is all-false) and avoids fragile reliance on the nullable flag. - withResource(arg.getBase.isNull) { isNullList => - withResource(cudf.Scalar.fromNull(outDType)) { nullScalar => - GpuColumnVector.from( - isNullList.ifElse(nullScalar, combined), dataType) - } - } - } - } - } + // Each step chains via `val x = withResource(...) { ... }` so the previous stage's + // intermediate GPU columns are released before the next stage allocates more. The + // exploded batch (can be large for long arrays) is the main thing we want to let go + // of as early as possible. + + // Step 1: g(x) over children + segmented reduce. Releases cb, transformedData, + // listOfGView as soon as the reduced per-row scalar column is materialized. + val reduced: cudf.ColumnVector = + withResource(makeElementProjectBatch(batch, arg)) { cb => + withResource(function.asInstanceOf[GpuExpression].columnarEval(cb)) { + transformedData => + withResource(GpuListUtils.replaceListDataColumnAsView( + arg.getBase, transformedData.getBase)) { listOfGView => + listOfGView.listReduce(op.cudfAgg, op.nullPolicy, outDType) } - } } } + + // Step 2: substitute op's identity for rows the reduce couldn't cover (per + // nullPolicy). Releases reduced, mask, and idScalar. + val adjusted: cudf.ColumnVector = withResource(reduced) { reduced => + withResource(substituteMask(arg.getBase, reduced)) { mask => + withResource(op.identityScalar(dataType, outDType)) { idScalar => + mask.ifElse(idScalar, reduced) + } + } + } + + // Step 3: combine with zero. Releases adjusted and zeroCv. + val combined: cudf.ColumnVector = withResource(adjusted) { adjusted => + withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => + op.combineWithZero(adjusted, zeroCv.getBase, outDType) + } + } + + // Step 4: restore null on rows where the input list itself was null. cuDF GREATER / + // LOGICAL_AND / LOGICAL_OR don't propagate null the way Spark's 3VL would, so the + // combine step alone can't preserve it. mergeNulls short-circuits if no nulls. + withResource(combined) { combined => + GpuColumnVector.from(NullUtilities.mergeNulls(combined, arg.getBase), dataType) } } } From 0a0f2ab8cadf798d6baef453b8c29f7a4b24be9a Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 23 Apr 2026 14:22:13 +0800 Subject: [PATCH 04/12] Restrict ExtremumOp to integral types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cuDF's segmented max / min and the combineWithZero compare + ifElse both follow IEEE 754, where fmax(NaN, x) = x (NaN is absorbed). Spark's Greatest / Least use Double.compare, which treats NaN as larger than every other value and propagates it. For an array column containing NaN, GPU would return a non-NaN result while CPU would return NaN — a data-correctness divergence flagged on the PR. Since customer workloads for ArrayAggregate MAX / MIN are integral-typed, take the conservative route: narrow ExtremumOp.supportsType to {Byte, Short, Int, Long} and fall back to CPU on Float / Double. Precise NaN propagation would require two extra segmented reduces per batch and explicit NaN handling in combineWithZero; leaving that for a follow-up if a real workload needs it. Added an integration test that verifies the Float / Double fallback. --- .../main/python/higher_order_functions_test.py | 17 +++++++++++++++++ .../spark/rapids/higherOrderFunctions.scala | 11 ++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 3d3b5c8afaa..7389d65950b 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -230,3 +230,20 @@ def test_array_aggregate_non_identity_finish_falls_back(): lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=5)).selectExpr( 'aggregate(a, 0L, (acc, x) -> acc + CAST(x as BIGINT), acc -> acc * 2) as doubled'), 'ArrayAggregate') + + +# MAX / MIN on float/double arrays must fall back: cuDF's segmented max/min follow IEEE 754 +# where NaN is absorbed (`fmax(NaN, x) = x`), while Spark's `Greatest`/`Least` propagate NaN +# via `Double.compare`. Rather than paper over this for now we restrict ExtremumOp to +# integral types and fall back on float/double. +@pytest.mark.parametrize('lambda_sql, init_sql', [ + ('(acc, x) -> greatest(acc, x)', 'CAST("-Infinity" as DOUBLE)'), + ('(acc, x) -> least(acc, x)', 'CAST("Infinity" as DOUBLE)'), +], ids=['max', 'min']) +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_double_extremum_falls_back(lambda_sql, init_sql): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(double_gen, max_length=5)).selectExpr( + f'aggregate(a, {init_sql}, {lambda_sql}) as res'), + 'ArrayAggregate') diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index ebdef7e0716..fe21e7cb94a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -975,10 +975,19 @@ case object ProductOp extends AggOp { /** * MaxOp / MinOp share EXCLUDE null policy: Spark's Greatest / Least skip null operands, * so an all-null list reduces to null and then folds back to zero via identity substitution. + * + * Float / Double are unsupported: cuDF's segmented `max` / `min` follow IEEE 754, where + * `fmax(NaN, x) = x` (NaN is absorbed). Spark's `Greatest` / `Least` use `Double.compare`, + * which treats NaN as larger than every other value and propagates it. The `combineWithZero` + * compare + ifElse also breaks on NaN (`greaterThan(NaN, z) = false`). Until we add an + * explicit NaN-propagation step, restrict to integral types. */ sealed trait ExtremumOp extends AggOp { val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.EXCLUDE - def supportsType(t: DataType): Boolean = AggOp.isPlainNumeric(t) + def supportsType(t: DataType): Boolean = t match { + case ByteType | ShortType | IntegerType | LongType => true + case _ => false + } } case object MaxOp extends ExtremumOp { From dc720c177930297444958c8f6f9c0a6d678b7d95 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 23 Apr 2026 14:25:32 +0800 Subject: [PATCH 05/12] Drop unreachable Float/Double arms from MaxOp/MinOp identityScalar ExtremumOp.supportsType already rejects Float/Double, so the per-type cascade in MaxOp.identityScalar and MinOp.identityScalar will never see them today. Remove the dead arms to keep the code honest; they can be added back when a follow-up adds real NaN propagation. --- .../scala/com/nvidia/spark/rapids/higherOrderFunctions.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index fe21e7cb94a..56f6efb0cb4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -998,8 +998,6 @@ case object MaxOp extends ExtremumOp { case ShortType => cudf.Scalar.fromShort(Short.MinValue) case IntegerType => cudf.Scalar.fromInt(Int.MinValue) case LongType => cudf.Scalar.fromLong(Long.MinValue) - case FloatType => cudf.Scalar.fromFloat(Float.NegativeInfinity) - case DoubleType => cudf.Scalar.fromDouble(Double.NegativeInfinity) case other => throw new IllegalStateException(s"MAX identity not defined for $other") } // cuDF's NULL_MAX treats null as smallest (wrong for Spark), so emulate null-propagating @@ -1024,8 +1022,6 @@ case object MinOp extends ExtremumOp { case ShortType => cudf.Scalar.fromShort(Short.MaxValue) case IntegerType => cudf.Scalar.fromInt(Int.MaxValue) case LongType => cudf.Scalar.fromLong(Long.MaxValue) - case FloatType => cudf.Scalar.fromFloat(Float.PositiveInfinity) - case DoubleType => cudf.Scalar.fromDouble(Double.PositiveInfinity) case other => throw new IllegalStateException(s"MIN identity not defined for $other") } def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) From f414e86db3a29700ab41a9eeb6a3024fcc0bca81 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 23 Apr 2026 14:53:22 +0800 Subject: [PATCH 06/12] Fix GpuArrayAggregate.nullable to match Spark semantics Previously `nullable = argument.nullable`, which is incorrect when the outer list is non-nullable but its elements can be null. For INCLUDE-policy ops (SUM / PRODUCT), a null element anywhere in a non-null list poisons the accumulator and yields a null output row. Reporting nullable=false in that case can let the Spark optimizer elide null checks and cause silent wrong results downstream. Spark's own ArrayAggregate.nullable returns `argument.nullable || finish.nullable`, and the finish lambda's acc variable is always bound with nullable=true (see ArrayAggregate.bind's `zero.dataType -> true`), so the CPU side is effectively always true. Match that. --- .../com/nvidia/spark/rapids/higherOrderFunctions.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 56f6efb0cb4..965bac61bae 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -1152,7 +1152,12 @@ case class GpuArrayAggregate( override def dataType: DataType = zero.dataType - override def nullable: Boolean = argument.nullable + // Matches Spark's ArrayAggregate.nullable = argument.nullable || finish.nullable. The + // finish lambda's accumulator variable is bound with nullable=true (Spark's + // ArrayAggregate.bind uses `zero.dataType -> true` for the acc slot), so the CPU side + // is effectively always true. Also covers the INCLUDE-policy case where a null element + // in a non-null list poisons the reduce and yields a null output row. + override def nullable: Boolean = true override def prettyName: String = "array_aggregate" From 8bb4eb3d6bc4a4af50f1918a3fffb8eba1c7f935 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 23 Apr 2026 15:28:34 +0800 Subject: [PATCH 07/12] Address review nits: outDType, decomposition.g, closeOnExcept, native int tests - AllOp / AnyOp combineWithZero now pass outDType to cuDF's and / or (ProductOp and SumOp were already doing this via add / mul). MaxOp / MinOp use ifElse, which has no outType argument; the output type there is determined by the inputs (both reduced and zero carry outDType already). - ArrayAggregateDecomposition now stores the g sub-expression directly instead of a gChildIndex. convertToGpuImpl locates the GPU g via fastEquals under the merge body's meta children rather than positional indexing, so we don't rely on the Add / Multiply / And / Or / Greatest / Least meta-children happening to be laid out as [left, right]. Decomposer unit tests assert on g identity. - Each val-chain boundary in columnarEval is now wrapped in closeOnExcept(x) { _ => withResource(x) { ... } } so the transitional window between a step's result being assigned and the next withResource taking ownership is covered. cuDF's ColumnVector.close is refcount-based, so the rare double-close on exception paths is benign. - Added a parametric native-integer integration test hitting int / long SUM, int MAX, and long MIN without the Cast-to-BIGINT that the existing numeric test uses, exercising identityScalar / combineWithZero on the primitive types directly. --- .../python/higher_order_functions_test.py | 19 +++++ .../spark/rapids/higherOrderFunctions.scala | 78 +++++++++++-------- .../ArrayAggregateDecomposerSuite.scala | 41 +++++----- 3 files changed, 85 insertions(+), 53 deletions(-) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 7389d65950b..d24a73eae19 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -58,6 +58,25 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) +# Same ops exercised on the native element type (no Cast in the lambda body), so the +# identityScalar / combineWithZero paths for Int / Long are hit directly. Covers the +# INCLUDE-policy null-element propagation for SUM on a nullable element type too. +@pytest.mark.parametrize('gen, lambda_sql, init_sql', [ + (IntegerGen(min_val=-100, max_val=100), '(acc, x) -> acc + x', '0'), + (LongGen(min_val=-100, max_val=100), '(acc, x) -> acc + x', '0L'), + (IntegerGen(min_val=-100, max_val=100), + '(acc, x) -> greatest(acc, x)', 'CAST(-9999 as INT)'), + (LongGen(min_val=-100, max_val=100), + '(acc, x) -> least(acc, x)', '9223372036854775807L'), +], ids=['int-sum', 'long-sum', 'int-max', 'long-min']) +@disable_ansi_mode +def test_array_aggregate_native_integer_ops(gen, lambda_sql, init_sql): + def do_it(spark): + return unary_op_df(spark, ArrayGen(gen, max_length=8)).selectExpr( + f'aggregate(a, {init_sql}, {lambda_sql}) as res') + assert_gpu_and_cpu_are_equal_collect(do_it) + + # Happy path for the boolean ops. Elements must be non-null because cuDF's segmented ALL/ # ANY with INCLUDE nulls don't match Spark's AND/OR 3VL for mixed null+bool (specifically, # `false AND null = false` short-circuit; `true OR null = true`). The tag-time guard falls diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 965bac61bae..803d4579873 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -1042,7 +1042,7 @@ case object AllOp extends AggOp { // INCLUDE: matches Spark's 3VL for AND (null AND true = null, null AND false = false). val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = cudf.Scalar.fromBool(true) - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.and(z) + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.and(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case a: And => Some((a.left, a.right)) case _ => None @@ -1055,7 +1055,7 @@ case object AnyOp extends AggOp { def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.any() val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = cudf.Scalar.fromBool(false) - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.or(z) + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.or(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case o: Or => Some((o.left, o.right)) case _ => None @@ -1069,13 +1069,16 @@ case object AnyOp extends AggOp { * registered AggOp. * * @param op the chosen aggregation operator - * @param gChildIndex 0 if `g` is the left child of the merge body's binary op, 1 if right + * @param g the Catalyst sub-expression corresponding to `g(x)` in the + * `(acc, x) -> op(acc, g(x))` rewrite — stored directly (rather than + * a child index) so convertToGpuImpl locates it by expression + * identity instead of relying on a meta-children ordering invariant * @param accVarExprId the accumulator NamedLambdaVariable's exprId * @param elemVar the element NamedLambdaVariable (used to build the g lambda) */ case class ArrayAggregateDecomposition( op: AggOp, - gChildIndex: Int, + g: Expression, accVarExprId: ExprId, elemVar: NamedLambdaVariable) @@ -1105,9 +1108,9 @@ object ArrayAggregateDecomposer { allOps.view.flatMap { op => op.matchBinary(body).flatMap { case (l, r) => if (isAccRef(l, accId) && !containsAccRef(r, accId)) { - Some(ArrayAggregateDecomposition(op, 1, accId, elemVar)) + Some(ArrayAggregateDecomposition(op, r, accId, elemVar)) } else if (isAccRef(r, accId) && !containsAccRef(l, accId)) { - Some(ArrayAggregateDecomposition(op, 0, accId, elemVar)) + Some(ArrayAggregateDecomposition(op, l, accId, elemVar)) } else None } }.headOption @@ -1218,8 +1221,14 @@ case class GpuArrayAggregate( // exploded batch (can be large for long arrays) is the main thing we want to let go // of as early as possible. - // Step 1: g(x) over children + segmented reduce. Releases cb, transformedData, - // listOfGView as soon as the reduced per-row scalar column is materialized. + // Each step chains via a `val x = closeOnExcept(...) { withResource(previous) { ... } }` + // idiom: closeOnExcept covers the tiny window between the previous step's result + // being assigned and `withResource` taking ownership, and the inner `withResource` + // ensures the previous step's column is released on both normal and exceptional + // paths. cuDF's ColumnVector.close is refcount-based so any late double-close on + // the rare exception path is benign. + + // Step 1: g(x) over children + segmented reduce. val reduced: cudf.ColumnVector = withResource(makeElementProjectBatch(batch, arg)) { cb => withResource(function.asInstanceOf[GpuExpression].columnarEval(cb)) { @@ -1231,28 +1240,33 @@ case class GpuArrayAggregate( } } - // Step 2: substitute op's identity for rows the reduce couldn't cover (per - // nullPolicy). Releases reduced, mask, and idScalar. - val adjusted: cudf.ColumnVector = withResource(reduced) { reduced => - withResource(substituteMask(arg.getBase, reduced)) { mask => - withResource(op.identityScalar(dataType, outDType)) { idScalar => - mask.ifElse(idScalar, reduced) + // Step 2: substitute op's identity for rows the reduce couldn't cover. + val adjusted: cudf.ColumnVector = closeOnExcept(reduced) { _ => + withResource(reduced) { reduced => + withResource(substituteMask(arg.getBase, reduced)) { mask => + withResource(op.identityScalar(dataType, outDType)) { idScalar => + mask.ifElse(idScalar, reduced) + } } } } - // Step 3: combine with zero. Releases adjusted and zeroCv. - val combined: cudf.ColumnVector = withResource(adjusted) { adjusted => - withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => - op.combineWithZero(adjusted, zeroCv.getBase, outDType) + // Step 3: combine with zero. + val combined: cudf.ColumnVector = closeOnExcept(adjusted) { _ => + withResource(adjusted) { adjusted => + withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => + op.combineWithZero(adjusted, zeroCv.getBase, outDType) + } } } // Step 4: restore null on rows where the input list itself was null. cuDF GREATER / // LOGICAL_AND / LOGICAL_OR don't propagate null the way Spark's 3VL would, so the // combine step alone can't preserve it. mergeNulls short-circuits if no nulls. - withResource(combined) { combined => - GpuColumnVector.from(NullUtilities.mergeNulls(combined, arg.getBase), dataType) + closeOnExcept(combined) { _ => + withResource(combined) { combined => + GpuColumnVector.from(NullUtilities.mergeNulls(combined, arg.getBase), dataType) + } } } } @@ -1290,14 +1304,10 @@ class GpuArrayAggregateMeta( } // g's output type must equal the accumulator/zero type so the segmented reduce output // matches the Spark-expected result type directly. - val body = expr.merge.asInstanceOf[LambdaFunction].function - val gType = decomp.op.matchBinary(body).get match { - case (_, r) if decomp.gChildIndex == 1 => r.dataType - case (l, _) => l.dataType - } - if (!DataType.equalsStructurally(gType, expr.zero.dataType, ignoreNullability = true)) { + if (!DataType.equalsStructurally( + decomp.g.dataType, expr.zero.dataType, ignoreNullability = true)) { willNotWorkOnGpu( - s"g(x) output type ($gType) does not match accumulator/zero type " + + s"g(x) output type (${decomp.g.dataType}) does not match accumulator/zero type " + s"(${expr.zero.dataType})") return } @@ -1324,12 +1334,16 @@ class GpuArrayAggregateMeta( val argGpu = childExprs.head.convertToGpu() val zeroGpu = childExprs(1).convertToGpu() - // childExprs(2) is the merge lambda meta; its first child is the op body meta, whose - // gChildIndex-th child is the g sub-expression. For binary catalyst shapes (Add, - // Multiply, And, Or) children are [left, right]; for variadic shapes restricted to - // size==2 (Greatest, Least) children are also [left, right]. So the index lines up. + // childExprs(2) is the merge lambda meta; its first child is the op body meta. Find + // the sub-meta whose wrapped CPU expression matches the g we recorded during + // decomposition, so we don't rely on meta-children ordering lining up with Catalyst's + // [left, right] convention. val bodyMeta = childExprs(2).childExprs.head - val gGpu = bodyMeta.childExprs(d.gChildIndex).convertToGpu() + val gMeta = bodyMeta.childExprs.find { + _.wrapped.asInstanceOf[Expression].fastEquals(d.g) + }.getOrElse(throw new IllegalStateException( + s"could not locate g sub-expression ${d.g} under merge body meta")) + val gGpu = gMeta.convertToGpu() val elemVarGpu = GpuNamedLambdaVariable( d.elemVar.name, d.elemVar.dataType, d.elemVar.nullable, d.elemVar.exprId) val gLambda = GpuLambdaFunction(gGpu, Seq(elemVarGpu)) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala index 1f9c249df47..f9257ddd873 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala @@ -54,11 +54,11 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { acc: NamedLambdaVariable, x: NamedLambdaVariable, expectedOp: AggOp, - expectedGChildIndex: Int): ArrayAggregateDecomposition = { + expectedG: Expression): ArrayAggregateDecomposition = { val d = decompose(merge(body, acc, x), identityFinish(acc)) assert(d.isDefined, s"expected decomposition for body=$body") assert(d.get.op == expectedOp) - assert(d.get.gChildIndex == expectedGChildIndex) + assert(d.get.g.fastEquals(expectedG), s"expected g=$expectedG, got ${d.get.g}") assert(d.get.accVarExprId == acc.exprId) assert(d.get.elemVar.exprId == x.exprId) d.get @@ -73,59 +73,58 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { // --- positive: one per op --------------------------------------------- - test("Add(acc, x) -> SUM, gChildIndex=1") { + test("Add(acc, x) -> SUM, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(plus(acc, x), acc, x, SumOp, 1) + assertDecomposes(plus(acc, x), acc, x, SumOp, x) } - test("Add(x, acc) (commuted) -> SUM, gChildIndex=0") { + test("Add(x, acc) (commuted) -> SUM, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(plus(x, acc), acc, x, SumOp, 0) + assertDecomposes(plus(x, acc), acc, x, SumOp, x) } - test("Multiply(acc, x) -> PRODUCT") { + test("Multiply(acc, x) -> PRODUCT, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(times(acc, x), acc, x, ProductOp, 1) + assertDecomposes(times(acc, x), acc, x, ProductOp, x) } - test("Greatest(acc, x) -> MAX") { + test("Greatest(acc, x) -> MAX, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(greatest(acc, x), acc, x, MaxOp, 1) + assertDecomposes(greatest(acc, x), acc, x, MaxOp, x) } - test("Least(acc, x) -> MIN") { + test("Least(acc, x) -> MIN, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(least(acc, x), acc, x, MinOp, 1) + assertDecomposes(least(acc, x), acc, x, MinOp, x) } - test("And(acc, x) -> ALL") { + test("And(acc, x) -> ALL, g = x") { val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) - assertDecomposes(And(acc, x), acc, x, AllOp, 1) + assertDecomposes(And(acc, x), acc, x, AllOp, x) } - test("Or(acc, x) -> ANY") { + test("Or(acc, x) -> ANY, g = x") { val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) - assertDecomposes(Or(acc, x), acc, x, AnyOp, 1) + assertDecomposes(Or(acc, x), acc, x, AnyOp, x) } // --- positive: structural variations ---------------------------------- - test("Complex g(x) with no acc ref still decomposes") { + test("Complex g(x) with no acc ref is captured verbatim") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) - // g = Cast(x * 2 + 1, Long) val g = Cast(plus(times(x, Literal(2)), Literal(1)), LongType) - assertDecomposes(plus(acc, g), acc, x, SumOp, 1) + assertDecomposes(plus(acc, g), acc, x, SumOp, g) } test("Cast wrapping the acc side is unwrapped (single layer)") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) - assertDecomposes(plus(Cast(acc, IntegerType), x), acc, x, SumOp, 1) + assertDecomposes(plus(Cast(acc, IntegerType), x), acc, x, SumOp, x) } test("Cast wrapping the acc side is unwrapped (chained)") { val acc = lv("acc"); val x = lv("x") val doubleCastAcc = Cast(Cast(acc, LongType), IntegerType) - assertDecomposes(plus(doubleCastAcc, x), acc, x, SumOp, 1) + assertDecomposes(plus(doubleCastAcc, x), acc, x, SumOp, x) } // --- negative: wrong shape -------------------------------------------- From 1d1ff231947a7298efe90f7f1b649181e3f3033b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 23 Apr 2026 17:19:06 +0800 Subject: [PATCH 08/12] doc generation Signed-off-by: Haoyang Li --- .../advanced_configs.md | 1 + docs/supported_ops.md | 578 +++++++++++------- tools/generated_files/330/operatorsScore.csv | 1 + tools/generated_files/330/supportedExprs.csv | 5 + tools/generated_files/331/operatorsScore.csv | 1 + tools/generated_files/331/supportedExprs.csv | 5 + tools/generated_files/332/operatorsScore.csv | 1 + tools/generated_files/332/supportedExprs.csv | 5 + tools/generated_files/333/operatorsScore.csv | 1 + tools/generated_files/333/supportedExprs.csv | 5 + tools/generated_files/334/operatorsScore.csv | 1 + tools/generated_files/334/supportedExprs.csv | 5 + tools/generated_files/340/operatorsScore.csv | 1 + tools/generated_files/340/supportedExprs.csv | 5 + tools/generated_files/341/operatorsScore.csv | 1 + tools/generated_files/341/supportedExprs.csv | 5 + tools/generated_files/342/operatorsScore.csv | 1 + tools/generated_files/342/supportedExprs.csv | 5 + tools/generated_files/343/operatorsScore.csv | 1 + tools/generated_files/343/supportedExprs.csv | 5 + tools/generated_files/344/operatorsScore.csv | 1 + tools/generated_files/344/supportedExprs.csv | 5 + tools/generated_files/350/operatorsScore.csv | 1 + tools/generated_files/350/supportedExprs.csv | 5 + tools/generated_files/351/operatorsScore.csv | 1 + tools/generated_files/351/supportedExprs.csv | 5 + tools/generated_files/353/operatorsScore.csv | 1 + tools/generated_files/353/supportedExprs.csv | 5 + tools/generated_files/354/operatorsScore.csv | 1 + tools/generated_files/354/supportedExprs.csv | 5 + tools/generated_files/355/operatorsScore.csv | 1 + tools/generated_files/355/supportedExprs.csv | 5 + tools/generated_files/356/operatorsScore.csv | 1 + tools/generated_files/356/supportedExprs.csv | 5 + tools/generated_files/357/operatorsScore.csv | 1 + tools/generated_files/357/supportedExprs.csv | 5 + tools/generated_files/358/operatorsScore.csv | 1 + tools/generated_files/358/supportedExprs.csv | 5 + tools/generated_files/400/operatorsScore.csv | 1 + tools/generated_files/400/supportedExprs.csv | 5 + tools/generated_files/401/operatorsScore.csv | 1 + tools/generated_files/401/supportedExprs.csv | 5 + tools/generated_files/402/operatorsScore.csv | 1 + tools/generated_files/402/supportedExprs.csv | 5 + tools/generated_files/411/operatorsScore.csv | 1 + tools/generated_files/411/supportedExprs.csv | 5 + tools/generated_files/operatorsScore.csv | 1 + tools/generated_files/supportedExprs.csv | 5 + 48 files changed, 488 insertions(+), 229 deletions(-) diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 3915309d509..22f495347e0 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -197,6 +197,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.Alias| |Gives a column a name|true|None| spark.rapids.sql.expression.And|`and`|Logical AND|true|None| spark.rapids.sql.expression.AnsiCast| |Convert a column of one type of data into another type|true|None| +spark.rapids.sql.expression.ArrayAggregate|`aggregate`|Aggregate elements in an array using an accumulator function and finishing transformation. Currently only lambdas of the form (acc, x) -> acc + g(x) with an identity finish are executed on the GPU; other shapes fall back to CPU.|true|None| spark.rapids.sql.expression.ArrayContains|`array_contains`|Returns a boolean if the array contains the passed in key|true|None| spark.rapids.sql.expression.ArrayDistinct|`array_distinct`|Removes duplicate values from the array|true|None| spark.rapids.sql.expression.ArrayExcept|`array_except`|Returns an array of the elements in array1 but not in array2, without duplicates|true|This is not 100% compatible with the Spark version because the GPU implementation treats -0.0 and 0.0 as equal, but the CPU implementation currently does not (see SPARK-39845). Also, Apache Spark 3.1.3 fixed issue SPARK-36741 where NaNs in these set like operators were not treated as being equal. We have chosen to break with compatibility for the older versions of Spark in this instance and handle NaNs the same as 3.1.3+| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 3a8a59d466c..d6cc3be9a2d 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -2357,6 +2357,154 @@ are limited. +ArrayAggregate +`aggregate` +Aggregate elements in an array using an accumulator function and finishing transformation. Currently only lambdas of the form (acc, x) -> acc + g(x) with an identity finish are executed on the GPU; other shapes fall back to CPU. +None +project +zero +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +NS +NS +NS +NS +NS +NS +NS +NS +NS + + +result +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +NS +NS +NS +NS +NS +NS +NS +NS +NS + + +finish +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +NS +NS +NS +NS +NS +NS +NS +NS +NS + + +merge +S +S +S +S +S +S +S +S +PS
UTC is only supported TZ for TIMESTAMP
+S +S +NS +NS +NS +NS +NS +NS +NS +NS +NS + + +argument + + + + + + + + + + + + + + +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
+ + + + + + + +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + ArrayContains `array_contains` Returns a boolean if the array contains the passed in key @@ -2482,34 +2630,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - ArrayExcept `array_except` Returns an array of the elements in array1 but not in array2, without duplicates @@ -2806,6 +2926,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + ArrayJoin `array_join` Concatenates the elements of the given array using the delimiter and an optional string to replace nulls. If no value is set for nullReplacement, any null value is filtered. @@ -2903,34 +3051,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - ArrayMax `array_max` Returns the maximum value in the array @@ -3127,7 +3247,7 @@ are limited. NS NS NS -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
NS NS NS @@ -3150,9 +3270,9 @@ are limited. S NS NS -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
-PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
-PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
+NS +NS +PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
NS NS NS @@ -3173,7 +3293,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
@@ -3255,6 +3375,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + ArrayTransform `transform` Transform elements in an array using the transform function. This is similar to a `map` in functional programming @@ -3329,34 +3477,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - ArrayUnion `array_union` Returns an array of the elements in the union of array1 and array2, without duplicates. @@ -3705,6 +3825,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + Asinh `asinh` Inverse hyperbolic sine @@ -3803,34 +3951,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - AtLeastNNonNulls Checks if number of non null/Nan values is greater than a given value @@ -4130,6 +4250,34 @@ are limited. S +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + BRound `bround` Round an expression to d decimal places using HALF_EVEN rounding mode @@ -4204,34 +4352,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - Bin `bin` Returns the string representation of the long value `expr` represented in binary @@ -4529,6 +4649,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + BitwiseNot `~` Returns the bitwise NOT of the operands @@ -4627,34 +4775,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - BitwiseOr `\|` Returns the bitwise OR of the operands @@ -4943,6 +5063,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + BloomFilterMightContain Bloom filter query @@ -5017,34 +5165,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - BoundReference Reference to a bound variable @@ -5371,6 +5491,34 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT +DAYTIME +YEARMONTH + + Coalesce `coalesce` Returns the first non-null argument if exists. Otherwise, null @@ -5422,34 +5570,6 @@ are limited. S -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT -DAYTIME -YEARMONTH - - Concat `concat` List/String concatenate diff --git a/tools/generated_files/330/operatorsScore.csv b/tools/generated_files/330/operatorsScore.csv index 0708fa821bc..b3794066a01 100644 --- a/tools/generated_files/330/operatorsScore.csv +++ b/tools/generated_files/330/operatorsScore.csv @@ -50,6 +50,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/330/supportedExprs.csv b/tools/generated_files/330/supportedExprs.csv index 77b62bcecd8..1aae119c31c 100644 --- a/tools/generated_files/330/supportedExprs.csv +++ b/tools/generated_files/330/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/331/operatorsScore.csv b/tools/generated_files/331/operatorsScore.csv index a8f66bbf1e2..d2abe472782 100644 --- a/tools/generated_files/331/operatorsScore.csv +++ b/tools/generated_files/331/operatorsScore.csv @@ -50,6 +50,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/331/supportedExprs.csv b/tools/generated_files/331/supportedExprs.csv index c7913cbb3d9..a501a1714cc 100644 --- a/tools/generated_files/331/supportedExprs.csv +++ b/tools/generated_files/331/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/332/operatorsScore.csv b/tools/generated_files/332/operatorsScore.csv index a8f66bbf1e2..d2abe472782 100644 --- a/tools/generated_files/332/operatorsScore.csv +++ b/tools/generated_files/332/operatorsScore.csv @@ -50,6 +50,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/332/supportedExprs.csv b/tools/generated_files/332/supportedExprs.csv index c7913cbb3d9..a501a1714cc 100644 --- a/tools/generated_files/332/supportedExprs.csv +++ b/tools/generated_files/332/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/333/operatorsScore.csv b/tools/generated_files/333/operatorsScore.csv index a8f66bbf1e2..d2abe472782 100644 --- a/tools/generated_files/333/operatorsScore.csv +++ b/tools/generated_files/333/operatorsScore.csv @@ -50,6 +50,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/333/supportedExprs.csv b/tools/generated_files/333/supportedExprs.csv index c7913cbb3d9..a501a1714cc 100644 --- a/tools/generated_files/333/supportedExprs.csv +++ b/tools/generated_files/333/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/334/operatorsScore.csv b/tools/generated_files/334/operatorsScore.csv index a8f66bbf1e2..d2abe472782 100644 --- a/tools/generated_files/334/operatorsScore.csv +++ b/tools/generated_files/334/operatorsScore.csv @@ -50,6 +50,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/334/supportedExprs.csv b/tools/generated_files/334/supportedExprs.csv index c7913cbb3d9..a501a1714cc 100644 --- a/tools/generated_files/334/supportedExprs.csv +++ b/tools/generated_files/334/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/340/operatorsScore.csv b/tools/generated_files/340/operatorsScore.csv index 0eb2ed2c9d1..adc973373e9 100644 --- a/tools/generated_files/340/operatorsScore.csv +++ b/tools/generated_files/340/operatorsScore.csv @@ -51,6 +51,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/340/supportedExprs.csv b/tools/generated_files/340/supportedExprs.csv index 482270f0d0e..b72deede8ed 100644 --- a/tools/generated_files/340/supportedExprs.csv +++ b/tools/generated_files/340/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/341/operatorsScore.csv b/tools/generated_files/341/operatorsScore.csv index 0eb2ed2c9d1..adc973373e9 100644 --- a/tools/generated_files/341/operatorsScore.csv +++ b/tools/generated_files/341/operatorsScore.csv @@ -51,6 +51,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/341/supportedExprs.csv b/tools/generated_files/341/supportedExprs.csv index 482270f0d0e..b72deede8ed 100644 --- a/tools/generated_files/341/supportedExprs.csv +++ b/tools/generated_files/341/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/342/operatorsScore.csv b/tools/generated_files/342/operatorsScore.csv index 0eb2ed2c9d1..adc973373e9 100644 --- a/tools/generated_files/342/operatorsScore.csv +++ b/tools/generated_files/342/operatorsScore.csv @@ -51,6 +51,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/342/supportedExprs.csv b/tools/generated_files/342/supportedExprs.csv index 482270f0d0e..b72deede8ed 100644 --- a/tools/generated_files/342/supportedExprs.csv +++ b/tools/generated_files/342/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/343/operatorsScore.csv b/tools/generated_files/343/operatorsScore.csv index 0eb2ed2c9d1..adc973373e9 100644 --- a/tools/generated_files/343/operatorsScore.csv +++ b/tools/generated_files/343/operatorsScore.csv @@ -51,6 +51,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/343/supportedExprs.csv b/tools/generated_files/343/supportedExprs.csv index 482270f0d0e..b72deede8ed 100644 --- a/tools/generated_files/343/supportedExprs.csv +++ b/tools/generated_files/343/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/344/operatorsScore.csv b/tools/generated_files/344/operatorsScore.csv index 0eb2ed2c9d1..adc973373e9 100644 --- a/tools/generated_files/344/operatorsScore.csv +++ b/tools/generated_files/344/operatorsScore.csv @@ -51,6 +51,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/344/supportedExprs.csv b/tools/generated_files/344/supportedExprs.csv index 482270f0d0e..b72deede8ed 100644 --- a/tools/generated_files/344/supportedExprs.csv +++ b/tools/generated_files/344/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/350/operatorsScore.csv b/tools/generated_files/350/operatorsScore.csv index adb17e0c312..2d3f273c462 100644 --- a/tools/generated_files/350/operatorsScore.csv +++ b/tools/generated_files/350/operatorsScore.csv @@ -58,6 +58,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/350/supportedExprs.csv b/tools/generated_files/350/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/350/supportedExprs.csv +++ b/tools/generated_files/350/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/351/operatorsScore.csv b/tools/generated_files/351/operatorsScore.csv index adb17e0c312..2d3f273c462 100644 --- a/tools/generated_files/351/operatorsScore.csv +++ b/tools/generated_files/351/operatorsScore.csv @@ -58,6 +58,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/351/supportedExprs.csv b/tools/generated_files/351/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/351/supportedExprs.csv +++ b/tools/generated_files/351/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/353/operatorsScore.csv b/tools/generated_files/353/operatorsScore.csv index 823d05ea694..06525f664cb 100644 --- a/tools/generated_files/353/operatorsScore.csv +++ b/tools/generated_files/353/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/353/supportedExprs.csv b/tools/generated_files/353/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/353/supportedExprs.csv +++ b/tools/generated_files/353/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/354/operatorsScore.csv b/tools/generated_files/354/operatorsScore.csv index 823d05ea694..06525f664cb 100644 --- a/tools/generated_files/354/operatorsScore.csv +++ b/tools/generated_files/354/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/354/supportedExprs.csv b/tools/generated_files/354/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/354/supportedExprs.csv +++ b/tools/generated_files/354/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/355/operatorsScore.csv b/tools/generated_files/355/operatorsScore.csv index 823d05ea694..06525f664cb 100644 --- a/tools/generated_files/355/operatorsScore.csv +++ b/tools/generated_files/355/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/355/supportedExprs.csv b/tools/generated_files/355/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/355/supportedExprs.csv +++ b/tools/generated_files/355/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/356/operatorsScore.csv b/tools/generated_files/356/operatorsScore.csv index 823d05ea694..06525f664cb 100644 --- a/tools/generated_files/356/operatorsScore.csv +++ b/tools/generated_files/356/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/356/supportedExprs.csv b/tools/generated_files/356/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/356/supportedExprs.csv +++ b/tools/generated_files/356/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/357/operatorsScore.csv b/tools/generated_files/357/operatorsScore.csv index 823d05ea694..06525f664cb 100644 --- a/tools/generated_files/357/operatorsScore.csv +++ b/tools/generated_files/357/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/357/supportedExprs.csv b/tools/generated_files/357/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/357/supportedExprs.csv +++ b/tools/generated_files/357/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/358/operatorsScore.csv b/tools/generated_files/358/operatorsScore.csv index 823d05ea694..06525f664cb 100644 --- a/tools/generated_files/358/operatorsScore.csv +++ b/tools/generated_files/358/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/358/supportedExprs.csv b/tools/generated_files/358/supportedExprs.csv index 6f70ef61b39..aad208ec635 100644 --- a/tools/generated_files/358/supportedExprs.csv +++ b/tools/generated_files/358/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/400/operatorsScore.csv b/tools/generated_files/400/operatorsScore.csv index b232ffeb8ed..89a2c26d3ae 100644 --- a/tools/generated_files/400/operatorsScore.csv +++ b/tools/generated_files/400/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/400/supportedExprs.csv b/tools/generated_files/400/supportedExprs.csv index 8cbd9dfe053..49a7efa1dd9 100644 --- a/tools/generated_files/400/supportedExprs.csv +++ b/tools/generated_files/400/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/401/operatorsScore.csv b/tools/generated_files/401/operatorsScore.csv index 2c7d4847e8a..c98bf4f3614 100644 --- a/tools/generated_files/401/operatorsScore.csv +++ b/tools/generated_files/401/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/401/supportedExprs.csv b/tools/generated_files/401/supportedExprs.csv index 701fc837634..b7f349c0da7 100644 --- a/tools/generated_files/401/supportedExprs.csv +++ b/tools/generated_files/401/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/402/operatorsScore.csv b/tools/generated_files/402/operatorsScore.csv index 2c7d4847e8a..c98bf4f3614 100644 --- a/tools/generated_files/402/operatorsScore.csv +++ b/tools/generated_files/402/operatorsScore.csv @@ -59,6 +59,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/402/supportedExprs.csv b/tools/generated_files/402/supportedExprs.csv index 701fc837634..b7f349c0da7 100644 --- a/tools/generated_files/402/supportedExprs.csv +++ b/tools/generated_files/402/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/411/operatorsScore.csv b/tools/generated_files/411/operatorsScore.csv index ed0aaad355b..87b8a403879 100644 --- a/tools/generated_files/411/operatorsScore.csv +++ b/tools/generated_files/411/operatorsScore.csv @@ -60,6 +60,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/411/supportedExprs.csv b/tools/generated_files/411/supportedExprs.csv index 7451a7072c3..d8e1a3e607b 100644 --- a/tools/generated_files/411/supportedExprs.csv +++ b/tools/generated_files/411/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`; `reduce`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`; `reduce`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA diff --git a/tools/generated_files/operatorsScore.csv b/tools/generated_files/operatorsScore.csv index 0708fa821bc..b3794066a01 100644 --- a/tools/generated_files/operatorsScore.csv +++ b/tools/generated_files/operatorsScore.csv @@ -50,6 +50,7 @@ AggregateExpression,4 Alias,4 And,4 ApproximatePercentile,4 +ArrayAggregate,4 ArrayContains,4 ArrayDistinct,4 ArrayExcept,4 diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv index 77b62bcecd8..1aae119c31c 100644 --- a/tools/generated_files/supportedExprs.csv +++ b/tools/generated_files/supportedExprs.csv @@ -27,6 +27,11 @@ And,S,`and`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,N And,S,`and`,None,AST,lhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,rhs,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA And,S,`and`,None,AST,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +ArrayAggregate,S,`aggregate`,None,project,zero,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,finish,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,merge,S,S,S,S,S,S,S,S,PS,S,S,NS,NS,NS,NS,NS,NS,NS,NS,NS +ArrayAggregate,S,`aggregate`,None,project,argument,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,array,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA ArrayContains,S,`array_contains`,None,project,key,S,S,S,S,S,S,S,S,PS,S,NS,NS,NS,NS,NS,NS,NS,NS,NS,NS ArrayContains,S,`array_contains`,None,project,result,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA From a511620235edd0681d21bf3420f5f0a8efabc30b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 28 Apr 2026 18:17:06 +0800 Subject: [PATCH 09/12] simplify Signed-off-by: Haoyang Li --- .../python/higher_order_functions_test.py | 8 +- .../spark/rapids/higherOrderFunctions.scala | 113 +++++++----------- .../ArrayAggregateDecomposerSuite.scala | 62 ++++------ 3 files changed, 69 insertions(+), 114 deletions(-) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index d24a73eae19..6acd954194d 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -109,7 +109,7 @@ def test_array_aggregate_boolean_ops_nullable_elements_fallback(lambda_sql, init 'ArrayAggregate') -# Count-if pattern (structural twin of the client's real workload). +# Count-if pattern: aggregate(array, 0, (acc, x) -> acc + CASE WHEN ... THEN 1 ELSE 0 END). @disable_ansi_mode def test_array_aggregate_count_if_int(): assert_gpu_and_cpu_are_equal_collect( @@ -118,9 +118,9 @@ def test_array_aggregate_count_if_int(): 'aggregate(a, 0L, (acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 1 ELSE 0 END as BIGINT)) as null_cnt')) -# Client's actual pattern (simplified): filter + aggregate with split / GetArrayItem / IN. +# Composed pattern: filter + aggregate with split / GetArrayItem / IN inside the lambda. @disable_ansi_mode -def test_array_aggregate_client_pattern(): +def test_array_aggregate_with_filter_and_split(): field_gen = StringGen('[a-z]{2}') def do_it(spark): df = unary_op_df(spark, ArrayGen(field_gen, max_length=5)) @@ -134,7 +134,7 @@ def do_it(spark): AND NOT split(z, ' ', -1)[1] IN ('xx', 'yy') ) THEN 1 ELSE 0 END as BIGINT), id -> id - ) as client_cnt""") + ) as cnt""") assert_gpu_and_cpu_are_equal_collect(do_it) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 803d4579873..f781ee56d28 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -28,7 +28,7 @@ import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.{Add, And, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, Cast, Expression, ExprId, Greatest, LambdaFunction, Least, Multiply, NamedExpression, NamedLambdaVariable, Or} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, NumericType, ShortType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch /** @@ -905,8 +905,8 @@ sealed trait AggOp { def name: String def cudfAgg: cudf.SegmentedReductionAggregation def nullPolicy: cudf.NullPolicy - /** Identity scalar, built with `cudfDType` so ifElse / binaryOp don't hit width mismatch. */ - def identityScalar(sparkType: DataType, cudfDType: DType): cudf.Scalar + /** Identity scalar typed to match `t` so ifElse / binaryOp don't hit width mismatch. */ + def identityScalar(t: DataType): cudf.Scalar /** `reduced OP zero`, typed to outDType, with Spark-matching null propagation. */ def combineWithZero( reduced: cudf.ColumnVector, @@ -917,21 +917,13 @@ sealed trait AggOp { def supportsType(sparkType: DataType): Boolean } -object AggOp { - /** Numeric types that all basic cuDF reductions (sum/product/max/min) accept. */ - def isPlainNumeric(t: DataType): Boolean = t match { - case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true - case _ => false - } -} - case object SumOp extends AggOp { val name = "SUM" def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.sum() // INCLUDE: Spark iteratively computes `acc + x` and null poisons the accumulator, so // one null element anywhere in the list yields null. val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE - def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + def identityScalar(t: DataType): cudf.Scalar = t match { case ByteType => cudf.Scalar.fromByte(0.toByte) case ShortType => cudf.Scalar.fromShort(0.toShort) case IntegerType => cudf.Scalar.fromInt(0) @@ -946,8 +938,7 @@ case object SumOp extends AggOp { case a: Add => Some((a.left, a.right)) case _ => None } - def supportsType(t: DataType): Boolean = - AggOp.isPlainNumeric(t) || t.isInstanceOf[DecimalType] + def supportsType(t: DataType): Boolean = t.isInstanceOf[NumericType] } case object ProductOp extends AggOp { @@ -955,7 +946,7 @@ case object ProductOp extends AggOp { def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.product() val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE - def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + def identityScalar(t: DataType): cudf.Scalar = t match { case ByteType => cudf.Scalar.fromByte(1.toByte) case ShortType => cudf.Scalar.fromShort(1.toShort) case IntegerType => cudf.Scalar.fromInt(1) @@ -969,21 +960,29 @@ case object ProductOp extends AggOp { case m: Multiply => Some((m.left, m.right)) case _ => None } - def supportsType(t: DataType): Boolean = AggOp.isPlainNumeric(t) + // Decimal would need DecimalUtils.multiplyDecimals for overflow handling — exclude for now. + def supportsType(t: DataType): Boolean = t match { + case _: NumericType => !t.isInstanceOf[DecimalType] + case _ => false + } } /** - * MaxOp / MinOp share EXCLUDE null policy: Spark's Greatest / Least skip null operands, - * so an all-null list reduces to null and then folds back to zero via identity substitution. + * MaxOp / MinOp share EXCLUDE null policy: Spark's Greatest / Least skip null operands. + * combineWithZero uses cuDF's NULL_MAX / NULL_MIN (the same primitive GpuGreatest/GpuLeast + * use), which returns the non-null operand when one side is null — exactly Spark's + * behavior on integral types. * * Float / Double are unsupported: cuDF's segmented `max` / `min` follow IEEE 754, where * `fmax(NaN, x) = x` (NaN is absorbed). Spark's `Greatest` / `Least` use `Double.compare`, - * which treats NaN as larger than every other value and propagates it. The `combineWithZero` - * compare + ifElse also breaks on NaN (`greaterThan(NaN, z) = false`). Until we add an + * which treats NaN as larger than every other value and propagates it. Until we add an * explicit NaN-propagation step, restrict to integral types. */ sealed trait ExtremumOp extends AggOp { val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.EXCLUDE + def binaryOp: cudf.BinaryOp + def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType): cudf.ColumnVector = + r.binaryOp(binaryOp, z, out) def supportsType(t: DataType): Boolean = t match { case ByteType | ShortType | IntegerType | LongType => true case _ => false @@ -993,21 +992,14 @@ sealed trait ExtremumOp extends AggOp { case object MaxOp extends ExtremumOp { val name = "MAX" def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.max() - def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + val binaryOp: cudf.BinaryOp = cudf.BinaryOp.NULL_MAX + def identityScalar(t: DataType): cudf.Scalar = t match { case ByteType => cudf.Scalar.fromByte(Byte.MinValue) case ShortType => cudf.Scalar.fromShort(Short.MinValue) case IntegerType => cudf.Scalar.fromInt(Int.MinValue) case LongType => cudf.Scalar.fromLong(Long.MinValue) case other => throw new IllegalStateException(s"MAX identity not defined for $other") } - // cuDF's NULL_MAX treats null as smallest (wrong for Spark), so emulate null-propagating - // max via compare + ifElse; null in the compare's result propagates through ifElse. - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) - : cudf.ColumnVector = { - withResource(r.greaterThan(z)) { rGreater => - rGreater.ifElse(r, z) - } - } def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case g: Greatest if g.children.size == 2 => Some((g.children.head, g.children(1))) case _ => None @@ -1017,19 +1009,14 @@ case object MaxOp extends ExtremumOp { case object MinOp extends ExtremumOp { val name = "MIN" def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.min() - def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = t match { + val binaryOp: cudf.BinaryOp = cudf.BinaryOp.NULL_MIN + def identityScalar(t: DataType): cudf.Scalar = t match { case ByteType => cudf.Scalar.fromByte(Byte.MaxValue) case ShortType => cudf.Scalar.fromShort(Short.MaxValue) case IntegerType => cudf.Scalar.fromInt(Int.MaxValue) case LongType => cudf.Scalar.fromLong(Long.MaxValue) case other => throw new IllegalStateException(s"MIN identity not defined for $other") } - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) - : cudf.ColumnVector = { - withResource(r.lessThan(z)) { rLess => - rLess.ifElse(r, z) - } - } def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case l: Least if l.children.size == 2 => Some((l.children.head, l.children(1))) case _ => None @@ -1041,7 +1028,7 @@ case object AllOp extends AggOp { def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.all() // INCLUDE: matches Spark's 3VL for AND (null AND true = null, null AND false = false). val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE - def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = cudf.Scalar.fromBool(true) + def identityScalar(t: DataType): cudf.Scalar = cudf.Scalar.fromBool(true) def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.and(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case a: And => Some((a.left, a.right)) @@ -1054,7 +1041,7 @@ case object AnyOp extends AggOp { val name = "ANY" def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.any() val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE - def identityScalar(t: DataType, cudfT: DType): cudf.Scalar = cudf.Scalar.fromBool(false) + def identityScalar(t: DataType): cudf.Scalar = cudf.Scalar.fromBool(false) def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.or(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case o: Or => Some((o.left, o.right)) @@ -1207,27 +1194,19 @@ case class GpuArrayAggregate( reduced.isNull } // Exclude null-list rows from the mask so the final null-restoration step handles them. - // For non-nullable columns this is effectively a no-op (isNotNull is all-true). - withResource(reducedIsEmpty) { m => - withResource(listCol.isNotNull) { isNotNull => m.and(isNotNull) } + // Skip when the input list has no nulls — `m.and(all-true)` is a wasted kernel. + if (listCol.getNullCount > 0) { + withResource(reducedIsEmpty) { m => + withResource(listCol.isNotNull) { isNotNull => m.and(isNotNull) } + } + } else { + reducedIsEmpty } } override def columnarEval(batch: ColumnarBatch): GpuColumnVector = { val outDType = GpuColumnVector.getNonNestedRapidsType(dataType) withResource(argument.asInstanceOf[GpuExpression].columnarEval(batch)) { arg => - // Each step chains via `val x = withResource(...) { ... }` so the previous stage's - // intermediate GPU columns are released before the next stage allocates more. The - // exploded batch (can be large for long arrays) is the main thing we want to let go - // of as early as possible. - - // Each step chains via a `val x = closeOnExcept(...) { withResource(previous) { ... } }` - // idiom: closeOnExcept covers the tiny window between the previous step's result - // being assigned and `withResource` taking ownership, and the inner `withResource` - // ensures the previous step's column is released on both normal and exceptional - // paths. cuDF's ColumnVector.close is refcount-based so any late double-close on - // the rare exception path is benign. - // Step 1: g(x) over children + segmented reduce. val reduced: cudf.ColumnVector = withResource(makeElementProjectBatch(batch, arg)) { cb => @@ -1241,32 +1220,30 @@ case class GpuArrayAggregate( } // Step 2: substitute op's identity for rows the reduce couldn't cover. - val adjusted: cudf.ColumnVector = closeOnExcept(reduced) { _ => - withResource(reduced) { reduced => - withResource(substituteMask(arg.getBase, reduced)) { mask => - withResource(op.identityScalar(dataType, outDType)) { idScalar => - mask.ifElse(idScalar, reduced) - } + val adjusted: cudf.ColumnVector = withResource(reduced) { reduced => + withResource(substituteMask(arg.getBase, reduced)) { mask => + withResource(op.identityScalar(dataType)) { idScalar => + mask.ifElse(idScalar, reduced) } } } // Step 3: combine with zero. - val combined: cudf.ColumnVector = closeOnExcept(adjusted) { _ => - withResource(adjusted) { adjusted => - withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => - op.combineWithZero(adjusted, zeroCv.getBase, outDType) - } + val combined: cudf.ColumnVector = withResource(adjusted) { adjusted => + withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => + op.combineWithZero(adjusted, zeroCv.getBase, outDType) } } - // Step 4: restore null on rows where the input list itself was null. cuDF GREATER / - // LOGICAL_AND / LOGICAL_OR don't propagate null the way Spark's 3VL would, so the - // combine step alone can't preserve it. mergeNulls short-circuits if no nulls. - closeOnExcept(combined) { _ => + // Step 4: restore null on rows where the input list itself was null. cuDF NULL_MAX / + // NULL_MIN / LOGICAL_AND / LOGICAL_OR don't propagate null the way Spark's 3VL would, + // so the combine step alone can't preserve it. Skip outright when the list has no nulls. + if (arg.getBase.getNullCount > 0) { withResource(combined) { combined => GpuColumnVector.from(NullUtilities.mergeNulls(combined, arg.getBase), dataType) } + } else { + GpuColumnVector.from(combined, dataType) } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala index f9257ddd873..ad6339619ab 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, LongType} class ArrayAggregateDecomposerSuite extends GpuUnitTests { import ArrayAggregateDecomposer.decompose - // --- helpers ----------------------------------------------------------- - private def lv(name: String, dt: DataType = IntegerType): NamedLambdaVariable = NamedLambdaVariable(name, dt, nullable = true, exprId = NamedExpression.newExprId) @@ -41,14 +39,6 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { private def identityFinish(acc: NamedLambdaVariable): LambdaFunction = LambdaFunction(acc, Seq(acc)) - private def plus(l: Expression, r: Expression): Add = Add(l, r) - private def minus(l: Expression, r: Expression): Subtract = Subtract(l, r) - private def times(l: Expression, r: Expression): Multiply = Multiply(l, r) - private def div(l: Expression, r: Expression): Divide = Divide(l, r) - private def greatest(l: Expression, r: Expression): Greatest = Greatest(Seq(l, r)) - private def least(l: Expression, r: Expression): Least = Least(Seq(l, r)) - - /** Assert decomposition succeeds; returns the ArrayAggregateDecomposition for further checks. */ private def assertDecomposes( body: Expression, acc: NamedLambdaVariable, @@ -71,31 +61,29 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { assert(decompose(mergeBody, finish).isEmpty, reason) } - // --- positive: one per op --------------------------------------------- - test("Add(acc, x) -> SUM, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(plus(acc, x), acc, x, SumOp, x) + assertDecomposes(Add(acc, x), acc, x, SumOp, x) } test("Add(x, acc) (commuted) -> SUM, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(plus(x, acc), acc, x, SumOp, x) + assertDecomposes(Add(x, acc), acc, x, SumOp, x) } test("Multiply(acc, x) -> PRODUCT, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(times(acc, x), acc, x, ProductOp, x) + assertDecomposes(Multiply(acc, x), acc, x, ProductOp, x) } test("Greatest(acc, x) -> MAX, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(greatest(acc, x), acc, x, MaxOp, x) + assertDecomposes(Greatest(Seq(acc, x)), acc, x, MaxOp, x) } test("Least(acc, x) -> MIN, g = x") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(least(acc, x), acc, x, MinOp, x) + assertDecomposes(Least(Seq(acc, x)), acc, x, MinOp, x) } test("And(acc, x) -> ALL, g = x") { @@ -108,36 +96,32 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { assertDecomposes(Or(acc, x), acc, x, AnyOp, x) } - // --- positive: structural variations ---------------------------------- - test("Complex g(x) with no acc ref is captured verbatim") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) - val g = Cast(plus(times(x, Literal(2)), Literal(1)), LongType) - assertDecomposes(plus(acc, g), acc, x, SumOp, g) + val g = Cast(Add(Multiply(x, Literal(2)), Literal(1)), LongType) + assertDecomposes(Add(acc, g), acc, x, SumOp, g) } test("Cast wrapping the acc side is unwrapped (single layer)") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) - assertDecomposes(plus(Cast(acc, IntegerType), x), acc, x, SumOp, x) + assertDecomposes(Add(Cast(acc, IntegerType), x), acc, x, SumOp, x) } test("Cast wrapping the acc side is unwrapped (chained)") { val acc = lv("acc"); val x = lv("x") val doubleCastAcc = Cast(Cast(acc, LongType), IntegerType) - assertDecomposes(plus(doubleCastAcc, x), acc, x, SumOp, x) + assertDecomposes(Add(doubleCastAcc, x), acc, x, SumOp, x) } - // --- negative: wrong shape -------------------------------------------- - test("Subtract is not an associative op we recognize") { val acc = lv("acc"); val x = lv("x") - assertRejects(merge(minus(acc, x), acc, x), identityFinish(acc), + assertRejects(merge(Subtract(acc, x), acc, x), identityFinish(acc), "Subtract is not in the registered AggOps") } test("Divide is not an associative op we recognize") { val acc = lv("acc"); val x = lv("x") - assertRejects(merge(div(acc, x), acc, x), identityFinish(acc), + assertRejects(merge(Divide(acc, x), acc, x), identityFinish(acc), "Divide is not in the registered AggOps") } @@ -150,31 +134,27 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { test("g that references acc is rejected") { val acc = lv("acc"); val x = lv("x") - // g = acc * x, references acc - assertRejects(merge(plus(acc, times(acc, x)), acc, x), identityFinish(acc), + assertRejects(merge(Add(acc, Multiply(acc, x)), acc, x), identityFinish(acc), "g must not reference acc") } test("both sides reference acc is rejected") { val acc = lv("acc"); val x = lv("x") - assertRejects(merge(plus(acc, acc), acc, x), identityFinish(acc), + assertRejects(merge(Add(acc, acc), acc, x), identityFinish(acc), "neither side is a 'pure non-acc'") } test("neither side is a pure acc ref is rejected") { val acc = lv("acc"); val x = lv("x") - // body = (acc + 1) + x - assertRejects(merge(plus(plus(acc, Literal(1)), x), acc, x), identityFinish(acc), + assertRejects(merge(Add(Add(acc, Literal(1)), x), acc, x), identityFinish(acc), "left side isn't a naked acc ref") } - // --- negative: finish lambda ------------------------------------------ - test("non-identity finish is rejected") { val acc = lv("acc"); val x = lv("x") val finishAcc = lv("finishAcc") - val badFinish = LambdaFunction(plus(finishAcc, Literal(1)), Seq(finishAcc)) - assertRejects(merge(plus(acc, x), acc, x), badFinish, + val badFinish = LambdaFunction(Add(finishAcc, Literal(1)), Seq(finishAcc)) + assertRejects(merge(Add(acc, x), acc, x), badFinish, "finish that multiplies the accumulator isn't identity") } @@ -183,25 +163,23 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { val finishAcc = lv("finishAcc") val otherVar = lv("other") val badFinish = LambdaFunction(otherVar, Seq(finishAcc)) - assertRejects(merge(plus(acc, x), acc, x), badFinish, + assertRejects(merge(Add(acc, x), acc, x), badFinish, "finish body refers to a variable that isn't its own arg") } - // --- negative: shape sanity -------------------------------------------- - test("merge with wrong arg count is rejected") { val acc = lv("acc"); val x = lv("x"); val extra = lv("extra") - val body = LambdaFunction(plus(acc, x), Seq(acc, x, extra)) + val body = LambdaFunction(Add(acc, x), Seq(acc, x, extra)) assertRejects(body, identityFinish(acc), "merge must take 2 lambda args") } test("merge that isn't a LambdaFunction at all is rejected") { val acc = lv("acc") - assert(decompose(plus(Literal(1), Literal(2)), identityFinish(acc)).isEmpty) + assert(decompose(Add(Literal(1), Literal(2)), identityFinish(acc)).isEmpty) } test("finish that isn't a LambdaFunction is rejected") { val acc = lv("acc"); val x = lv("x") - assert(decompose(merge(plus(acc, x), acc, x), Literal(0)).isEmpty) + assert(decompose(merge(Add(acc, x), acc, x), Literal(0)).isEmpty) } } From fe66642949638f25af6b1825a88103a539661a41 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 29 Apr 2026 10:07:54 +0800 Subject: [PATCH 10/12] refactor Signed-off-by: Haoyang Li --- .../spark/rapids/higherOrderFunctions.scala | 186 ++++++++++-------- .../ArrayAggregateDecomposerSuite.scala | 108 +++++++--- 2 files changed, 182 insertions(+), 112 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index f781ee56d28..5f61f60b031 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -213,6 +213,12 @@ trait GpuSimpleHigherOrderFunction extends GpuHigherOrderFunction with GpuBind { } +/** + * Common explode + lambda projection plumbing for higher-order functions over arrays. + * Subclasses choose how to consume the lambda's per-element result by either extending + * GpuArrayElementWiseTransform (one row in -> one row out via transformListColumnView) or + * implementing columnarEval themselves (e.g. segmented reductions like GpuArrayAggregate). + */ trait GpuArrayTransformBase extends GpuSimpleHigherOrderFunction { def isBound: Boolean def boundIntermediate: Seq[GpuExpression] @@ -275,6 +281,14 @@ trait GpuArrayTransformBase extends GpuSimpleHigherOrderFunction { } } +} + +/** + * Specialization for HOFs that produce one output row per input row by post-processing the + * lambda's elementwise result. Subclasses implement transformListColumnView and inherit the + * standard columnarEval that drives the explode -> lambda eval -> rewrap chain. + */ +trait GpuArrayElementWiseTransform extends GpuArrayTransformBase { /* * Post-process the column view of the array after applying the function parameter. * @param lambdaTransformedCV the results of the lambda expression running @@ -303,7 +317,7 @@ case class GpuArrayTransform( argument: Expression, function: Expression, isBound: Boolean = false, - boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayTransformBase { + boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayElementWiseTransform { override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) @@ -326,7 +340,7 @@ case class GpuArrayExists( function: Expression, followThreeValuedLogic: Boolean, isBound: Boolean = false, - boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayTransformBase { + boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayElementWiseTransform { override def dataType: DataType = BooleanType @@ -424,7 +438,7 @@ case class GpuArrayFilter( argument: Expression, function: Expression, isBound: Boolean = false, - boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayTransformBase { + boundIntermediate: Seq[GpuExpression] = Seq.empty) extends GpuArrayElementWiseTransform { override def dataType: DataType = argument.dataType @@ -907,10 +921,15 @@ sealed trait AggOp { def nullPolicy: cudf.NullPolicy /** Identity scalar typed to match `t` so ifElse / binaryOp don't hit width mismatch. */ def identityScalar(t: DataType): cudf.Scalar - /** `reduced OP zero`, typed to outDType, with Spark-matching null propagation. */ + /** + * `reduced OP zero`, typed to outDType, with Spark-matching null propagation. `zero` is + * a `BinaryOperable` so callers can pass either a `cudf.Scalar` (when the Spark-side + * `zero` is a Literal — saves one full-row column allocation per batch) or a `ColumnView` + * (when `zero` references an outer column). + */ def combineWithZero( reduced: cudf.ColumnVector, - zero: cudf.ColumnView, + zero: cudf.BinaryOperable, outDType: DType): cudf.ColumnVector /** Return (left, right) if the body is this op's Catalyst shape. */ def matchBinary(body: Expression): Option[(Expression, Expression)] @@ -933,7 +952,7 @@ case object SumOp extends AggOp { case d: DecimalType => GpuScalar.from(0, d) case other => throw new IllegalStateException(s"SUM identity not defined for $other") } - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.add(z, out) + def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.add(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case a: Add => Some((a.left, a.right)) case _ => None @@ -955,7 +974,7 @@ case object ProductOp extends AggOp { case DoubleType => cudf.Scalar.fromDouble(1.0) case other => throw new IllegalStateException(s"PRODUCT identity not defined for $other") } - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.mul(z, out) + def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.mul(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case m: Multiply => Some((m.left, m.right)) case _ => None @@ -981,8 +1000,8 @@ case object ProductOp extends AggOp { sealed trait ExtremumOp extends AggOp { val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.EXCLUDE def binaryOp: cudf.BinaryOp - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType): cudf.ColumnVector = - r.binaryOp(binaryOp, z, out) + def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) + : cudf.ColumnVector = r.binaryOp(binaryOp, z, out) def supportsType(t: DataType): Boolean = t match { case ByteType | ShortType | IntegerType | LongType => true case _ => false @@ -1029,7 +1048,7 @@ case object AllOp extends AggOp { // INCLUDE: matches Spark's 3VL for AND (null AND true = null, null AND false = false). val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType): cudf.Scalar = cudf.Scalar.fromBool(true) - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.and(z, out) + def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.and(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case a: And => Some((a.left, a.right)) case _ => None @@ -1042,7 +1061,7 @@ case object AnyOp extends AggOp { def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.any() val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType): cudf.Scalar = cudf.Scalar.fromBool(false) - def combineWithZero(r: cudf.ColumnVector, z: cudf.ColumnView, out: DType) = r.or(z, out) + def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.or(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case o: Or => Some((o.left, o.right)) case _ => None @@ -1055,17 +1074,17 @@ case object AnyOp extends AggOp { * Result of successfully matching a Spark ArrayAggregate's merge lambda against a * registered AggOp. * - * @param op the chosen aggregation operator - * @param g the Catalyst sub-expression corresponding to `g(x)` in the - * `(acc, x) -> op(acc, g(x))` rewrite — stored directly (rather than - * a child index) so convertToGpuImpl locates it by expression - * identity instead of relying on a meta-children ordering invariant - * @param accVarExprId the accumulator NamedLambdaVariable's exprId - * @param elemVar the element NamedLambdaVariable (used to build the g lambda) + * @param op the chosen aggregation operator + * @param gIsLeftOfMergeBody whether `g(x)` is the left child of the merge body's binary op + * (true) or the right child (false). convertToGpuImpl uses this + * index to pick the matching meta-child without re-walking the + * Catalyst tree + * @param accVarExprId the accumulator NamedLambdaVariable's exprId + * @param elemVar the element NamedLambdaVariable (used to build the g lambda) */ case class ArrayAggregateDecomposition( op: AggOp, - g: Expression, + gIsLeftOfMergeBody: Boolean, accVarExprId: ExprId, elemVar: NamedLambdaVariable) @@ -1073,34 +1092,72 @@ case class ArrayAggregateDecomposition( /** * Decomposes a Spark ArrayAggregate's merge lambda of shape `(acc, x) -> op(acc, g(x))` * where `op` is one of the registered AggOps and the finish lambda is identity. + * + * decompose owns every reason ArrayAggregate cannot run on the GPU — shape, type, and + * nullability — so the meta layer is just a single Either match. */ object ArrayAggregateDecomposer { /** All ops the decomposer will try, in order. */ val allOps: Seq[AggOp] = Seq(SumOp, ProductOp, MaxOp, MinOp, AllOp, AnyOp) - def decompose(merge: Expression, finish: Expression): Option[ArrayAggregateDecomposition] = { + def decompose( + merge: Expression, + finish: Expression, + argType: DataType, + zeroType: DataType): Either[String, ArrayAggregateDecomposition] = { val mergeLambda = merge match { case lf: LambdaFunction => lf - case _ => return None + case _ => return Left("merge expression is not a LambdaFunction") } val (accVar, elemVar) = mergeLambda.arguments match { case Seq(a: NamedLambdaVariable, e: NamedLambdaVariable) => (a, e) - case _ => return None + case _ => return Left("merge lambda must take exactly 2 NamedLambdaVariable arguments") + } + if (!isFinishIdentity(finish)) { + return Left("finish lambda is not an identity (only `acc -> acc` is supported)") } - if (!isFinishIdentity(finish)) return None val body = mergeLambda.function val accId = accVar.exprId - allOps.view.flatMap { op => + val matched = allOps.view.flatMap { op => op.matchBinary(body).flatMap { case (l, r) => if (isAccRef(l, accId) && !containsAccRef(r, accId)) { - Some(ArrayAggregateDecomposition(op, r, accId, elemVar)) + Some((op, r, /* gIsLeftOfMergeBody = */ false)) } else if (isAccRef(r, accId) && !containsAccRef(l, accId)) { - Some(ArrayAggregateDecomposition(op, l, accId, elemVar)) + Some((op, l, /* gIsLeftOfMergeBody = */ true)) } else None } }.headOption + + val (op, g, gIsLeft) = matched.getOrElse { + return Left("merge body does not match (acc, x) -> op(acc, g(x)) for any registered " + + "op (" + allOps.map(_.name).mkString(", ") + ")") + } + + if (!op.supportsType(zeroType)) { + return Left(s"${op.name} is not supported on GPU for type $zeroType") + } + // g's output type must equal the accumulator/zero type so the segmented reduce output + // matches the Spark-expected result type directly. + if (!DataType.equalsStructurally(g.dataType, zeroType, ignoreNullability = true)) { + return Left(s"g(x) output type (${g.dataType}) does not match accumulator/zero type " + + s"($zeroType)") + } + // cuDF's segmented ALL/ANY with INCLUDE nulls doesn't match Spark's AND/OR 3VL + // (specifically: `false AND null = false` short-circuit, or `true OR null = true`, are + // both missed by cuDF which returns null whenever any null is present). Fall back to + // CPU when the input array can contain nulls. + if (op == AllOp || op == AnyOp) { + argType match { + case ArrayType(_, true) => + return Left(s"${op.name} is not supported on GPU for arrays that may contain " + + "nulls; cuDF's INCLUDE-nulls semantics don't match Spark's AND/OR 3VL") + case _ => + } + } + + Right(ArrayAggregateDecomposition(op, gIsLeft, accId, elemVar)) } private def isFinishIdentity(finish: Expression): Boolean = finish match { @@ -1159,13 +1216,6 @@ case class GpuArrayAggregate( GpuArrayAggregate(boundArg, boundZero, boundFunc, op, isBound = true, boundInter) } - override protected def transformListColumnView( - lambdaTransformedCV: cudf.ColumnView, - arg: cudf.ColumnView): GpuColumnVector = { - throw new IllegalStateException( - "GpuArrayAggregate overrides columnarEval; transformListColumnView is unused") - } - /** * Mask of rows where the reduce result must be replaced with the op's identity. * @@ -1228,10 +1278,19 @@ case class GpuArrayAggregate( } } - // Step 3: combine with zero. + // Step 3: combine with zero. When `zero` is a Literal (the common 4-arg + // `aggregate(arr, 0, ...)` shape) skip the per-batch column broadcast and pass a + // cudf.Scalar instead — `add/mul/and/or/binaryOp` all accept BinaryOperable. val combined: cudf.ColumnVector = withResource(adjusted) { adjusted => - withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => - op.combineWithZero(adjusted, zeroCv.getBase, outDType) + zero match { + case lit: GpuLiteral => + withResource(GpuScalar.from(lit.value, lit.dataType)) { zeroScalar => + op.combineWithZero(adjusted, zeroScalar, outDType) + } + case _ => + withResource(zero.asInstanceOf[GpuExpression].columnarEval(batch)) { zeroCv => + op.combineWithZero(adjusted, zeroCv.getBase, outDType) + } } } @@ -1265,44 +1324,11 @@ class GpuArrayAggregateMeta( private var decomposition: Option[ArrayAggregateDecomposition] = None override def tagExprForGpu(): Unit = { - val d = ArrayAggregateDecomposer.decompose(expr.merge, expr.finish) - if (d.isEmpty) { - willNotWorkOnGpu( - "ArrayAggregate only supports lambdas of the form (acc, x) -> op(acc, g(x)) " + - "with an identity finish lambda, where op is one of " + - ArrayAggregateDecomposer.allOps.map(_.name).mkString(", ") + ".") - return - } - val decomp = d.get - if (!decomp.op.supportsType(expr.zero.dataType)) { - willNotWorkOnGpu( - s"${decomp.op.name} is not supported on GPU for type ${expr.zero.dataType}") - return - } - // g's output type must equal the accumulator/zero type so the segmented reduce output - // matches the Spark-expected result type directly. - if (!DataType.equalsStructurally( - decomp.g.dataType, expr.zero.dataType, ignoreNullability = true)) { - willNotWorkOnGpu( - s"g(x) output type (${decomp.g.dataType}) does not match accumulator/zero type " + - s"(${expr.zero.dataType})") - return - } - // cuDF's segmented ALL/ANY with INCLUDE nulls doesn't match Spark's AND/OR 3VL - // (specifically: `false AND null = false` short-circuit, or `true OR null = true`, are - // both missed by cuDF which returns null whenever any null is present). Fall back to - // CPU when the input array can contain nulls. - if (decomp.op == AllOp || decomp.op == AnyOp) { - expr.argument.dataType match { - case ArrayType(_, containsNull) if containsNull => - willNotWorkOnGpu( - s"${decomp.op.name} is not supported on GPU for arrays that may contain nulls; " + - "cuDF's INCLUDE-nulls semantics don't match Spark's AND/OR 3VL") - return - case _ => - } + ArrayAggregateDecomposer.decompose( + expr.merge, expr.finish, expr.argument.dataType, expr.zero.dataType) match { + case Left(reason) => willNotWorkOnGpu(reason) + case Right(d) => decomposition = Some(d) } - decomposition = d } override def convertToGpuImpl(): GpuExpression = { @@ -1311,15 +1337,13 @@ class GpuArrayAggregateMeta( val argGpu = childExprs.head.convertToGpu() val zeroGpu = childExprs(1).convertToGpu() - // childExprs(2) is the merge lambda meta; its first child is the op body meta. Find - // the sub-meta whose wrapped CPU expression matches the g we recorded during - // decomposition, so we don't rely on meta-children ordering lining up with Catalyst's - // [left, right] convention. + // childExprs(2) is the merge lambda meta; its first child is the op body meta. The + // decomposer already recorded which side `g(x)` is on (gIsLeftOfMergeBody), so pick by + // index. Meta children mirror Catalyst children order, which is stable for every op + // matchBinary accepts (Add/Multiply/And/Or are commutative-shape, Greatest/Least take + // a Seq[Expression] in declaration order). val bodyMeta = childExprs(2).childExprs.head - val gMeta = bodyMeta.childExprs.find { - _.wrapped.asInstanceOf[Expression].fastEquals(d.g) - }.getOrElse(throw new IllegalStateException( - s"could not locate g sub-expression ${d.g} under merge body meta")) + val gMeta = if (d.gIsLeftOfMergeBody) bodyMeta.childExprs.head else bodyMeta.childExprs(1) val gGpu = gMeta.convertToGpu() val elemVarGpu = GpuNamedLambdaVariable( d.elemVar.name, d.elemVar.dataType, d.elemVar.nullable, d.elemVar.exprId) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala index ad6339619ab..2c96f2c4b44 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala @@ -19,7 +19,8 @@ package com.nvidia.spark.rapids import org.apache.spark.sql.catalyst.expressions.{Add, And, Cast, Divide, Expression, Greatest, LambdaFunction, Least, Literal, Multiply, NamedExpression, NamedLambdaVariable, Or, Subtract} -import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, LongType} +import org.apache.spark.sql.types.{ArrayType, BooleanType, DataType, DoubleType, IntegerType, + LongType} // Extends GpuUnitTests so SQLConf.get is available for the default evalMode / failOnError // parameter on Add/Subtract/Multiply/Divide (the field name differs across Spark versions; @@ -39,78 +40,91 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { private def identityFinish(acc: NamedLambdaVariable): LambdaFunction = LambdaFunction(acc, Seq(acc)) + /** Wrap zeroType in an ArrayType(_, containsNull = false) for the typical happy path. */ + private def arrTy(zeroType: DataType): ArrayType = ArrayType(zeroType, containsNull = false) + private def assertDecomposes( body: Expression, acc: NamedLambdaVariable, x: NamedLambdaVariable, expectedOp: AggOp, - expectedG: Expression): ArrayAggregateDecomposition = { - val d = decompose(merge(body, acc, x), identityFinish(acc)) - assert(d.isDefined, s"expected decomposition for body=$body") - assert(d.get.op == expectedOp) - assert(d.get.g.fastEquals(expectedG), s"expected g=$expectedG, got ${d.get.g}") - assert(d.get.accVarExprId == acc.exprId) - assert(d.get.elemVar.exprId == x.exprId) - d.get + expectedGIsLeft: Boolean, + zeroType: DataType = IntegerType, + argType: Option[DataType] = None): ArrayAggregateDecomposition = { + val d = decompose(merge(body, acc, x), identityFinish(acc), + argType.getOrElse(arrTy(zeroType)), zeroType) + val r = d.getOrElse(fail(s"expected decomposition for body=$body, got Left: $d")) + assert(r.op == expectedOp) + assert(r.gIsLeftOfMergeBody == expectedGIsLeft, + s"expected gIsLeftOfMergeBody=$expectedGIsLeft, got ${r.gIsLeftOfMergeBody}") + assert(r.accVarExprId == acc.exprId) + assert(r.elemVar.exprId == x.exprId) + r } private def assertRejects( mergeBody: LambdaFunction, finish: Expression, - reason: String): Unit = { - assert(decompose(mergeBody, finish).isEmpty, reason) + reason: String, + zeroType: DataType = IntegerType, + argType: Option[DataType] = None): String = { + val d = decompose(mergeBody, finish, argType.getOrElse(arrTy(zeroType)), zeroType) + assert(d.isLeft, s"$reason — expected Left but got: $d") + d.swap.getOrElse(fail("unreachable")) } - test("Add(acc, x) -> SUM, g = x") { + test("Add(acc, x) -> SUM, g on the right") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Add(acc, x), acc, x, SumOp, x) + assertDecomposes(Add(acc, x), acc, x, SumOp, expectedGIsLeft = false) } - test("Add(x, acc) (commuted) -> SUM, g = x") { + test("Add(x, acc) (commuted) -> SUM, g on the left") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Add(x, acc), acc, x, SumOp, x) + assertDecomposes(Add(x, acc), acc, x, SumOp, expectedGIsLeft = true) } - test("Multiply(acc, x) -> PRODUCT, g = x") { + test("Multiply(acc, x) -> PRODUCT") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Multiply(acc, x), acc, x, ProductOp, x) + assertDecomposes(Multiply(acc, x), acc, x, ProductOp, expectedGIsLeft = false) } - test("Greatest(acc, x) -> MAX, g = x") { + test("Greatest(acc, x) -> MAX") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Greatest(Seq(acc, x)), acc, x, MaxOp, x) + assertDecomposes(Greatest(Seq(acc, x)), acc, x, MaxOp, expectedGIsLeft = false) } - test("Least(acc, x) -> MIN, g = x") { + test("Least(acc, x) -> MIN") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Least(Seq(acc, x)), acc, x, MinOp, x) + assertDecomposes(Least(Seq(acc, x)), acc, x, MinOp, expectedGIsLeft = false) } - test("And(acc, x) -> ALL, g = x") { + test("And(acc, x) -> ALL") { val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) - assertDecomposes(And(acc, x), acc, x, AllOp, x) + assertDecomposes(And(acc, x), acc, x, AllOp, expectedGIsLeft = false, + zeroType = BooleanType) } - test("Or(acc, x) -> ANY, g = x") { + test("Or(acc, x) -> ANY") { val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) - assertDecomposes(Or(acc, x), acc, x, AnyOp, x) + assertDecomposes(Or(acc, x), acc, x, AnyOp, expectedGIsLeft = false, + zeroType = BooleanType) } - test("Complex g(x) with no acc ref is captured verbatim") { + test("Complex g(x) with no acc ref decomposes (g on the right)") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) val g = Cast(Add(Multiply(x, Literal(2)), Literal(1)), LongType) - assertDecomposes(Add(acc, g), acc, x, SumOp, g) + assertDecomposes(Add(acc, g), acc, x, SumOp, expectedGIsLeft = false, zeroType = LongType) } test("Cast wrapping the acc side is unwrapped (single layer)") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) - assertDecomposes(Add(Cast(acc, IntegerType), x), acc, x, SumOp, x) + assertDecomposes(Add(Cast(acc, IntegerType), x), acc, x, SumOp, expectedGIsLeft = false) } test("Cast wrapping the acc side is unwrapped (chained)") { val acc = lv("acc"); val x = lv("x") val doubleCastAcc = Cast(Cast(acc, LongType), IntegerType) - assertDecomposes(Add(doubleCastAcc, x), acc, x, SumOp, x) + assertDecomposes(Add(doubleCastAcc, x), acc, x, SumOp, expectedGIsLeft = false) } test("Subtract is not an associative op we recognize") { @@ -175,11 +189,43 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { test("merge that isn't a LambdaFunction at all is rejected") { val acc = lv("acc") - assert(decompose(Add(Literal(1), Literal(2)), identityFinish(acc)).isEmpty) + assert(decompose(Add(Literal(1), Literal(2)), identityFinish(acc), + arrTy(IntegerType), IntegerType).isLeft) } test("finish that isn't a LambdaFunction is rejected") { val acc = lv("acc"); val x = lv("x") - assert(decompose(merge(Add(acc, x), acc, x), Literal(0)).isEmpty) + assert(decompose(merge(Add(acc, x), acc, x), Literal(0), + arrTy(IntegerType), IntegerType).isLeft) + } + + // The decomposer now owns the "is this shape ever GPU-able" decision, so it must also + // reject unsupported types and AllOp/AnyOp on null-bearing arrays. + + test("MaxOp on Double is rejected (NaN propagation differs from cuDF)") { + val acc = lv("acc", DoubleType); val x = lv("x", DoubleType) + val msg = assertRejects(merge(Greatest(Seq(acc, x)), acc, x), identityFinish(acc), + "MAX should fall back on Double", + zeroType = DoubleType) + assert(msg.contains("MAX"), s"expected MAX-related error, got: $msg") + } + + test("ALL on array with containsNull rejects") { + val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) + val msg = assertRejects(merge(And(acc, x), acc, x), identityFinish(acc), + "ALL on null-bearing array should fall back", + zeroType = BooleanType, + argType = Some(ArrayType(BooleanType, containsNull = true))) + assert(msg.contains("ALL"), s"expected ALL-related error, got: $msg") + } + + test("g type mismatch with zero type rejects") { + val acc = lv("acc", LongType); val x = lv("x", IntegerType) + // body sums a non-cast Int element into a Long acc — g.dataType=Int doesn't match + // zeroType=Long, so this must fall back even though the shape is otherwise OK. + val msg = assertRejects(merge(Add(acc, x), acc, x), identityFinish(acc), + "g type mismatch should fall back", + zeroType = LongType) + assert(msg.contains("does not match"), s"expected type-mismatch error, got: $msg") } } From 84c4730812f4c9078d0ad5dccbc1f872a7bbf555 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 30 Apr 2026 14:32:26 +0800 Subject: [PATCH 11/12] bug fix Signed-off-by: Haoyang Li --- docs/supported_ops.md | 4 +- .../python/higher_order_functions_test.py | 74 +++--- .../nvidia/spark/rapids/GpuOverrides.scala | 9 +- .../spark/rapids/higherOrderFunctions.scala | 218 +++++++++++++++--- .../ArrayAggregateDecomposerSuite.scala | 115 +++++++-- 5 files changed, 334 insertions(+), 86 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index d6cc3be9a2d..4c002359a0c 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -2359,7 +2359,7 @@ are limited. ArrayAggregate `aggregate` -Aggregate elements in an array using an accumulator function and finishing transformation. Currently only lambdas of the form (acc, x) -> acc + g(x) with an identity finish are executed on the GPU; other shapes fall back to CPU. +Aggregate elements in an array using an accumulator function and finishing transformation. Currently only lambdas of the form (acc, x) -> op(acc, g(x)) with an identity finish are executed on the GPU, where op is one of SUM/PRODUCT/MAX/MIN/ALL/ANY. If/CaseWhen branches are accepted as long as each branch is itself op-of-acc (or bare acc) with op consistent across branches; other shapes fall back to CPU. None project zero @@ -2469,7 +2469,7 @@ are limited. -PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types BINARY, CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
+PS
UTC is only supported TZ for child TIMESTAMP;
unsupported child types CALENDAR, ARRAY, MAP, UDT, DAYTIME, YEARMONTH
diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 6acd954194d..5c411566306 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -34,15 +34,6 @@ def do_project(spark): assert_gpu_and_cpu_are_equal_collect(do_project, conf=confs) -# --- ArrayAggregate tests --- -# -# The decomposer accepts lambdas of the form `(acc, x) -> op(acc, g(x))` with an identity -# finish, where `op` is one of SUM/PRODUCT/MAX/MIN/ALL/ANY. Other shapes fall back to CPU. - - -# Happy path for each supported numeric op. Product uses a narrow range to keep the test -# output numerically tame (GPU and CPU both wrap consistently, but small numbers make the -# test easier to read when debugging a failure). @pytest.mark.parametrize('lambda_sql, init_sql, gen_max', [ ('(acc, x) -> acc + CAST(x as BIGINT)', '0L', 100), ('(acc, x) -> acc * CAST(x as BIGINT)', '1L', 3), @@ -58,9 +49,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Same ops exercised on the native element type (no Cast in the lambda body), so the -# identityScalar / combineWithZero paths for Int / Long are hit directly. Covers the -# INCLUDE-policy null-element propagation for SUM on a nullable element type too. @pytest.mark.parametrize('gen, lambda_sql, init_sql', [ (IntegerGen(min_val=-100, max_val=100), '(acc, x) -> acc + x', '0'), (LongGen(min_val=-100, max_val=100), '(acc, x) -> acc + x', '0L'), @@ -77,10 +65,8 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Happy path for the boolean ops. Elements must be non-null because cuDF's segmented ALL/ -# ANY with INCLUDE nulls don't match Spark's AND/OR 3VL for mixed null+bool (specifically, -# `false AND null = false` short-circuit; `true OR null = true`). The tag-time guard falls -# back to CPU when the element type is nullable, so here we use a non-nullable BooleanGen. +# Elements are non-null because the tag-time guard falls back to CPU when the element type +# is nullable. @pytest.mark.parametrize('lambda_sql, init_sql', [ ('(acc, x) -> acc AND x', 'true'), ('(acc, x) -> acc OR x', 'false'), @@ -94,8 +80,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# When array elements may contain nulls, ALL/ANY must fall back to CPU (cuDF's INCLUDE- -# nulls semantics don't match Spark's AND/OR 3VL). @pytest.mark.parametrize('lambda_sql, init_sql', [ ('(acc, x) -> acc AND x', 'true'), ('(acc, x) -> acc OR x', 'false'), @@ -109,7 +93,6 @@ def test_array_aggregate_boolean_ops_nullable_elements_fallback(lambda_sql, init 'ArrayAggregate') -# Count-if pattern: aggregate(array, 0, (acc, x) -> acc + CASE WHEN ... THEN 1 ELSE 0 END). @disable_ansi_mode def test_array_aggregate_count_if_int(): assert_gpu_and_cpu_are_equal_collect( @@ -118,7 +101,29 @@ def test_array_aggregate_count_if_int(): 'aggregate(a, 0L, (acc, x) -> acc + CAST(CASE WHEN x IS NULL THEN 1 ELSE 0 END as BIGINT)) as null_cnt')) -# Composed pattern: filter + aggregate with split / GetArrayItem / IN inside the lambda. +# `if(cond, acc + t, acc)` shape — branches lifted via op identity. Same count-if +# pattern as above but written naturally instead of using `CASE WHEN ... THEN 1 ELSE 0`. +@disable_ansi_mode +def test_array_aggregate_if_count(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=15)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> if(x > 0, acc + 1L, acc)) as pos_cnt', + 'aggregate(a, 0L, (acc, x) -> if(x is null, acc, acc + 1L)) as nonnull_cnt')) + + +# CaseWhen with several acc+t branches and a bare-acc else. +@disable_ansi_mode +def test_array_aggregate_casewhen_multi_branch(): + assert_gpu_and_cpu_are_equal_collect( + lambda spark: unary_op_df(spark, ArrayGen(int_gen, max_length=15)).selectExpr( + '''aggregate(a, 0L, + (acc, x) -> CASE + WHEN x > 100 THEN acc + 10L + WHEN x > 0 THEN acc + 1L + ELSE acc + END) as weighted_cnt''')) + + @disable_ansi_mode def test_array_aggregate_with_filter_and_split(): field_gen = StringGen('[a-z]{2}') @@ -138,7 +143,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Non-zero init: result should include the init. @disable_ansi_mode def test_array_aggregate_non_zero_init(): assert_gpu_and_cpu_are_equal_collect( @@ -146,7 +150,6 @@ def test_array_aggregate_non_zero_init(): 'aggregate(a, 100L, (acc, x) -> acc + CAST(x as BIGINT)) as sum_with_init')) -# null array -> null, empty array -> finish(init) = init. @disable_ansi_mode def test_array_aggregate_null_array(): assert_gpu_and_cpu_are_equal_collect( @@ -164,7 +167,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Lambda body references an outer attribute — exercises boundIntermediate plumbing. @disable_ansi_mode def test_array_aggregate_lambda_refs_outer_column(): def do_it(spark): @@ -173,7 +175,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# zero is an outer column, not a literal. @disable_ansi_mode def test_array_aggregate_zero_is_outer_column(): def do_it(spark): @@ -182,7 +183,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# array: accumulate over a struct field. @disable_ansi_mode def test_array_aggregate_over_struct_field(): def do_it(spark): @@ -192,7 +192,16 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Deeper g body without acc references (x * 2 + 1). +@disable_ansi_mode +def test_array_aggregate_over_binary(): + # GpuLength only accepts STRING, so we hex(binary) → string first to keep the + # whole lambda on the GPU. Result: 2 × byte count of each element, summed. + def do_it(spark): + return unary_op_df(spark, ArrayGen(BinaryGen(max_length=10), max_length=8)).selectExpr( + 'aggregate(a, 0L, (acc, x) -> acc + CAST(length(hex(x)) as BIGINT)) as total_hex_len') + assert_gpu_and_cpu_are_equal_collect(do_it) + + @disable_ansi_mode def test_array_aggregate_deeper_g_body(): assert_gpu_and_cpu_are_equal_collect( @@ -200,7 +209,7 @@ def test_array_aggregate_deeper_g_body(): 'aggregate(a, 0L, (acc, x) -> acc + CAST(x * 2 + 1 as BIGINT)) as sum_poly')) -# Long-overflow wrap-around matches between Spark SUM and cudf SUM in non-ANSI mode. +# Long overflow wraps in non-ANSI mode on both Spark SUM and cuDF SUM. @disable_ansi_mode def test_array_aggregate_long_overflow_wraps(): def do_it(spark): @@ -210,8 +219,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Decimal SUM: zero must be widened to DECIMAL(38,2) (Spark's cap) with the element Cast to -# match so that merge.dataType == zero.dataType (Spark's checkInputDataTypes). @disable_ansi_mode def test_array_aggregate_decimal_sum(): decimal_gen = DecimalGen(precision=10, scale=2) @@ -222,9 +229,6 @@ def do_it(spark): assert_gpu_and_cpu_are_equal_collect(do_it) -# Shapes the decomposer rejects must fall back to CPU. Covered: non-associative op -# (Subtract, Divide), variadic op with wrong arity (Greatest with 3 children), and a lambda -# whose g sub-expression references the accumulator. @pytest.mark.parametrize('lambda_sql, init_sql', [ ('(acc, x) -> acc - CAST(x as BIGINT)', '0L'), ('(acc, x) -> CAST(acc / CAST(x + 1 as BIGINT) as BIGINT)', '1L'), @@ -240,8 +244,6 @@ def test_array_aggregate_fallback_shapes(lambda_sql, init_sql): 'ArrayAggregate') -# Non-identity finish is kept as its own test because its SQL shape (4-arg aggregate with -# a separate finish lambda) differs from the merge-only fallbacks above. @disable_ansi_mode @allow_non_gpu('ProjectExec') def test_array_aggregate_non_identity_finish_falls_back(): @@ -251,10 +253,6 @@ def test_array_aggregate_non_identity_finish_falls_back(): 'ArrayAggregate') -# MAX / MIN on float/double arrays must fall back: cuDF's segmented max/min follow IEEE 754 -# where NaN is absorbed (`fmax(NaN, x) = x`), while Spark's `Greatest`/`Least` propagate NaN -# via `Double.compare`. Rather than paper over this for now we restrict ExtremumOp to -# integral types and fall back on float/double. @pytest.mark.parametrize('lambda_sql, init_sql', [ ('(acc, x) -> greatest(acc, x)', 'CAST("-Infinity" as DOUBLE)'), ('(acc, x) -> least(acc, x)', 'CAST("Infinity" as DOUBLE)'), diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 5c6e40fc457..3f991450081 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2964,15 +2964,18 @@ object GpuOverrides extends Logging { }), expr[ArrayAggregate]( "Aggregate elements in an array using an accumulator function and finishing " + - "transformation. Currently only lambdas of the form (acc, x) -> acc + g(x) with an " + - "identity finish are executed on the GPU; other shapes fall back to CPU.", + "transformation. Currently only lambdas of the form (acc, x) -> op(acc, g(x)) with " + + "an identity finish are executed on the GPU, where op is one of SUM/PRODUCT/MAX/" + + "MIN/ALL/ANY. If/CaseWhen branches are accepted as long as each branch is itself " + + "op-of-acc (or bare acc) with op consistent across branches; other shapes fall " + + "back to CPU.", ExprChecks.projectOnly( TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, TypeSig.all, Seq( ParamCheck("argument", TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.NULL + - TypeSig.STRUCT), + TypeSig.BINARY + TypeSig.STRUCT), TypeSig.ARRAY.nested(TypeSig.all)), ParamCheck("zero", TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 5f61f60b031..00810155113 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -26,9 +26,9 @@ import com.nvidia.spark.rapids.jni.GpuMapZipWithUtils import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.analysis.TypeCoercion -import org.apache.spark.sql.catalyst.expressions.{Add, And, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, Cast, Expression, ExprId, Greatest, LambdaFunction, Least, Multiply, NamedExpression, NamedLambdaVariable, Or} +import org.apache.spark.sql.catalyst.expressions.{Add, And, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, CaseWhen, Cast, Expression, ExprId, Greatest, If, LambdaFunction, Least, Literal, Multiply, NamedExpression, NamedLambdaVariable, Or} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, NumericType, ShortType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, NumericType, ShortType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch /** @@ -921,6 +921,12 @@ sealed trait AggOp { def nullPolicy: cudf.NullPolicy /** Identity scalar typed to match `t` so ifElse / binaryOp don't hit width mismatch. */ def identityScalar(t: DataType): cudf.Scalar + /** + * Catalyst-side identity, used by the decomposer to plug into `If`/`CaseWhen` branches + * that are bare `acc` (treated as `op(acc, identity)` so the branch can be lifted out). + * Must satisfy `op(acc, identityLiteral(t)) == acc` for any acc of type `t`. + */ + def identityLiteral(t: DataType): Literal /** * `reduced OP zero`, typed to outDType, with Spark-matching null propagation. `zero` is * a `BinaryOperable` so callers can pass either a `cudf.Scalar` (when the Spark-side @@ -952,6 +958,18 @@ case object SumOp extends AggOp { case d: DecimalType => GpuScalar.from(0, d) case other => throw new IllegalStateException(s"SUM identity not defined for $other") } + // Each arm builds the value at exactly the right Scala type so Spark's Literal + // type-compatibility check (validateLiteralValue) doesn't reject e.g. Int into LongType. + def identityLiteral(t: DataType): Literal = t match { + case ByteType => Literal(0.toByte, ByteType) + case ShortType => Literal(0.toShort, ShortType) + case IntegerType => Literal(0, IntegerType) + case LongType => Literal(0L, LongType) + case FloatType => Literal(0.0f, FloatType) + case DoubleType => Literal(0.0, DoubleType) + case d: DecimalType => Literal(Decimal(0L, d.precision, d.scale), d) + case other => throw new IllegalStateException(s"SUM identity literal not defined for $other") + } def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.add(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case a: Add => Some((a.left, a.right)) @@ -974,6 +992,16 @@ case object ProductOp extends AggOp { case DoubleType => cudf.Scalar.fromDouble(1.0) case other => throw new IllegalStateException(s"PRODUCT identity not defined for $other") } + def identityLiteral(t: DataType): Literal = t match { + case ByteType => Literal(1.toByte, ByteType) + case ShortType => Literal(1.toShort, ShortType) + case IntegerType => Literal(1, IntegerType) + case LongType => Literal(1L, LongType) + case FloatType => Literal(1.0f, FloatType) + case DoubleType => Literal(1.0, DoubleType) + case other => throw new IllegalStateException( + s"PRODUCT identity literal not defined for $other") + } def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.mul(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case m: Multiply => Some((m.left, m.right)) @@ -1019,6 +1047,13 @@ case object MaxOp extends ExtremumOp { case LongType => cudf.Scalar.fromLong(Long.MinValue) case other => throw new IllegalStateException(s"MAX identity not defined for $other") } + def identityLiteral(t: DataType): Literal = t match { + case ByteType => Literal(Byte.MinValue, ByteType) + case ShortType => Literal(Short.MinValue, ShortType) + case IntegerType => Literal(Int.MinValue, IntegerType) + case LongType => Literal(Long.MinValue, LongType) + case other => throw new IllegalStateException(s"MAX identity literal not defined for $other") + } def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case g: Greatest if g.children.size == 2 => Some((g.children.head, g.children(1))) case _ => None @@ -1036,6 +1071,13 @@ case object MinOp extends ExtremumOp { case LongType => cudf.Scalar.fromLong(Long.MaxValue) case other => throw new IllegalStateException(s"MIN identity not defined for $other") } + def identityLiteral(t: DataType): Literal = t match { + case ByteType => Literal(Byte.MaxValue, ByteType) + case ShortType => Literal(Short.MaxValue, ShortType) + case IntegerType => Literal(Int.MaxValue, IntegerType) + case LongType => Literal(Long.MaxValue, LongType) + case other => throw new IllegalStateException(s"MIN identity literal not defined for $other") + } def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case l: Least if l.children.size == 2 => Some((l.children.head, l.children(1))) case _ => None @@ -1048,6 +1090,7 @@ case object AllOp extends AggOp { // INCLUDE: matches Spark's 3VL for AND (null AND true = null, null AND false = false). val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType): cudf.Scalar = cudf.Scalar.fromBool(true) + def identityLiteral(t: DataType): Literal = Literal(true, BooleanType) def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.and(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case a: And => Some((a.left, a.right)) @@ -1061,6 +1104,7 @@ case object AnyOp extends AggOp { def cudfAgg: cudf.SegmentedReductionAggregation = cudf.SegmentedReductionAggregation.any() val nullPolicy: cudf.NullPolicy = cudf.NullPolicy.INCLUDE def identityScalar(t: DataType): cudf.Scalar = cudf.Scalar.fromBool(false) + def identityLiteral(t: DataType): Literal = Literal(false, BooleanType) def combineWithZero(r: cudf.ColumnVector, z: cudf.BinaryOperable, out: DType) = r.or(z, out) def matchBinary(e: Expression): Option[(Expression, Expression)] = e match { case o: Or => Some((o.left, o.right)) @@ -1074,17 +1118,17 @@ case object AnyOp extends AggOp { * Result of successfully matching a Spark ArrayAggregate's merge lambda against a * registered AggOp. * - * @param op the chosen aggregation operator - * @param gIsLeftOfMergeBody whether `g(x)` is the left child of the merge body's binary op - * (true) or the right child (false). convertToGpuImpl uses this - * index to pick the matching meta-child without re-walking the - * Catalyst tree - * @param accVarExprId the accumulator NamedLambdaVariable's exprId - * @param elemVar the element NamedLambdaVariable (used to build the g lambda) + * @param op the chosen aggregation operator + * @param g the lifted Catalyst sub-expression for `g(x)`. For a plain + * `op(acc, g(x))` body this is the body's non-acc child; for an + * `If` / `CaseWhen` body it is rebuilt with bare-acc branches + * replaced by `op.identityLiteral` so it never references acc + * @param accVarExprId the accumulator NamedLambdaVariable's exprId + * @param elemVar the element NamedLambdaVariable (used to build the g lambda) */ case class ArrayAggregateDecomposition( op: AggOp, - gIsLeftOfMergeBody: Boolean, + g: Expression, accVarExprId: ExprId, elemVar: NamedLambdaVariable) @@ -1121,18 +1165,13 @@ object ArrayAggregateDecomposer { val body = mergeLambda.function val accId = accVar.exprId val matched = allOps.view.flatMap { op => - op.matchBinary(body).flatMap { case (l, r) => - if (isAccRef(l, accId) && !containsAccRef(r, accId)) { - Some((op, r, /* gIsLeftOfMergeBody = */ false)) - } else if (isAccRef(r, accId) && !containsAccRef(l, accId)) { - Some((op, l, /* gIsLeftOfMergeBody = */ true)) - } else None - } + extractG(body, accId, op, zeroType).map { case (g, gIsLeft) => (op, g, gIsLeft) } }.headOption - val (op, g, gIsLeft) = matched.getOrElse { + val (op, g, _) = matched.getOrElse { return Left("merge body does not match (acc, x) -> op(acc, g(x)) for any registered " + - "op (" + allOps.map(_.name).mkString(", ") + ")") + "op (" + allOps.map(_.name).mkString(", ") + "); If / CaseWhen branches must each " + + "be op-of-acc with acc on a consistent side") } if (!op.supportsType(zeroType)) { @@ -1157,7 +1196,127 @@ object ArrayAggregateDecomposer { } } - Right(ArrayAggregateDecomposition(op, gIsLeft, accId, elemVar)) + Right(ArrayAggregateDecomposition(op, g, accId, elemVar)) + } + + /** + * Try to extract g from the merge body, given a candidate op. Returns + * `Some((g, gIsLeft))` on success — `gIsLeft` is internal bookkeeping used by + * `alignSides` to ensure all branches of an If/CaseWhen agree on which side acc lives; + * it is not exposed in the final Decomposition (the lifted g already encodes the + * answer structurally). + * + * Three accepted shapes: + * 1. body is `op(acc, g)` or `op(g, acc)` — direct match. + * 2. body is `If(cond, t, f)` where each of `t`, `f` is itself accepted by this + * function (recursively) with `acc` on the *same* side, and `cond` doesn't + * reference `acc`. Lifted to `op(acc, If(cond, g_t, g_f))` (or the symmetric form + * with acc on the left). + * 3. body is `CaseWhen(branches, Some(else))` — generalised If for N branches. + * + * Bare `acc` in a branch is treated as `op(acc, identityLiteral)` — that's how we + * support the `If(cond, acc + 1, acc)` form: the right branch is bare acc, replaced + * with `acc + 0`, then the whole If lifts out as `acc + If(cond, 1, 0)`. + */ + private def extractG( + body: Expression, + accId: ExprId, + op: AggOp, + accType: DataType): Option[(Expression, Boolean)] = { + matchOpOfAcc(body, accId, op).orElse(extractFromBranching(body, accId, op, accType)) + } + + /** body matches op directly: returns `(g, gIsLeft)` if body is `op(acc, g)` / `op(g, acc)`. */ + private def matchOpOfAcc( + e: Expression, + accId: ExprId, + op: AggOp): Option[(Expression, Boolean)] = { + op.matchBinary(e).flatMap { case (l, r) => + if (isAccRef(l, accId) && !containsAccRef(r, accId)) Some((r, false)) + else if (isAccRef(r, accId) && !containsAccRef(l, accId)) Some((l, true)) + else None + } + } + + /** + * Decompose a single If/CaseWhen branch. Either it's an op-of-acc form (returns the + * non-acc side and whether acc was on the left), or it's a bare acc-ref (returns the + * op's identity literal, gIsLeft=false — the placeholder side doesn't matter, we only + * need branches to agree on it later). + * + * Recursively delegates to `extractG` so nested If is handled. + */ + private def extractBranch( + branch: Expression, + accId: ExprId, + op: AggOp, + accType: DataType): Option[(Expression, Boolean)] = { + if (isAccRef(branch, accId)) { + Some((op.identityLiteral(accType), /* gIsLeft = */ false)) + } else { + extractG(branch, accId, op, accType) + } + } + + private def extractFromBranching( + body: Expression, + accId: ExprId, + op: AggOp, + accType: DataType): Option[(Expression, Boolean)] = body match { + case If(cond, t, f) if !containsAccRef(cond, accId) => + for { + (tG, tIsLeft) <- extractBranch(t, accId, op, accType) + (fG, fIsLeft) <- extractBranch(f, accId, op, accType) + // The "bare acc" case picks gIsLeft=false. If a branch is bare acc, accept either + // side from the other branch — we'll just rebuild to that side. + gIsLeft <- alignSides(t, f, tIsLeft, fIsLeft, accId) + } yield (If(cond, tG, fG), gIsLeft) + + case CaseWhen(branches, Some(elseValue)) + if branches.forall { case (c, _) => !containsAccRef(c, accId) } => + // Decompose every (cond, val) branch + the else branch. All op-of-acc branches must + // agree on which side acc is on; bare-acc branches don't constrain. + val branchDecs = branches.map { case (c, v) => (c, extractBranch(v, accId, op, accType)) } + val elseDec = extractBranch(elseValue, accId, op, accType) + if (branchDecs.exists(_._2.isEmpty) || elseDec.isEmpty) { + None + } else { + val allBranchExprs: Seq[Expression] = branches.map(_._2) :+ elseValue + val allSides: Seq[Boolean] = branchDecs.map(_._2.get._2) :+ elseDec.get._2 + val constrainedSides = allBranchExprs.zip(allSides).collect { + case (br, side) if !isAccRef(br, accId) => side + } + if (constrainedSides.distinct.size > 1) { + None + } else { + val gIsLeft = constrainedSides.headOption.getOrElse(false) + val gBranches = branchDecs.map { case (c, dec) => (c, dec.get._1) } + Some((CaseWhen(gBranches, Some(elseDec.get._1)), gIsLeft)) + } + } + + case _ => None + } + + /** + * Reconcile the gIsLeft flags from two If branches. Bare-acc branches don't constrain + * (their identity placeholder is symmetric), so this is `agree if both constrained, + * else borrow from the constrained one`. + */ + private def alignSides( + tBranch: Expression, + fBranch: Expression, + tIsLeft: Boolean, + fIsLeft: Boolean, + accId: ExprId): Option[Boolean] = { + val tBare = isAccRef(tBranch, accId) + val fBare = isAccRef(fBranch, accId) + (tBare, fBare) match { + case (true, true) => Some(false) // both bare acc — fold has no actual op to apply + case (true, false) => Some(fIsLeft) + case (false, true) => Some(tIsLeft) + case (false, false) => if (tIsLeft == fIsLeft) Some(tIsLeft) else None + } } private def isFinishIdentity(finish: Expression): Boolean = finish match { @@ -1337,13 +1496,18 @@ class GpuArrayAggregateMeta( val argGpu = childExprs.head.convertToGpu() val zeroGpu = childExprs(1).convertToGpu() - // childExprs(2) is the merge lambda meta; its first child is the op body meta. The - // decomposer already recorded which side `g(x)` is on (gIsLeftOfMergeBody), so pick by - // index. Meta children mirror Catalyst children order, which is stable for every op - // matchBinary accepts (Add/Multiply/And/Or are commutative-shape, Greatest/Least take - // a Seq[Expression] in declaration order). - val bodyMeta = childExprs(2).childExprs.head - val gMeta = if (d.gIsLeftOfMergeBody) bodyMeta.childExprs.head else bodyMeta.childExprs(1) + // The lifted g may have a different shape from any sub-tree of the original merge body + // (If/CaseWhen branches get rewritten and identity literals get inserted), so we can't + // pick it out of childExprs(2)'s meta tree by index. Wrap g as a fresh ExprMeta and let + // spark-rapids tag/convert it. Sub-expressions inherited from the original body get + // re-tagged here, but tag is idempotent and they were already proven GPU-compatible + // when the parent ArrayAggregate was tagged. + val gMeta = GpuOverrides.wrapExpr(d.g, this.conf, Some(this)) + gMeta.tagForGpu() + if (!gMeta.canThisBeReplaced) { + throw new IllegalStateException( + s"could not convert g sub-expression ${d.g} to GPU: ${gMeta.explain(all = false)}") + } val gGpu = gMeta.convertToGpu() val elemVarGpu = GpuNamedLambdaVariable( d.elemVar.name, d.elemVar.dataType, d.elemVar.nullable, d.elemVar.exprId) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala index 2c96f2c4b44..69032fda86f 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ArrayAggregateDecomposerSuite.scala @@ -16,9 +16,9 @@ package com.nvidia.spark.rapids -import org.apache.spark.sql.catalyst.expressions.{Add, And, Cast, Divide, Expression, - Greatest, LambdaFunction, Least, Literal, Multiply, NamedExpression, NamedLambdaVariable, - Or, Subtract} +import org.apache.spark.sql.catalyst.expressions.{Add, And, CaseWhen, Cast, Divide, EqualTo, + Expression, GreaterThan, Greatest, If, LambdaFunction, Least, Literal, Multiply, + NamedExpression, NamedLambdaVariable, Or, Subtract} import org.apache.spark.sql.types.{ArrayType, BooleanType, DataType, DoubleType, IntegerType, LongType} @@ -48,15 +48,16 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { acc: NamedLambdaVariable, x: NamedLambdaVariable, expectedOp: AggOp, - expectedGIsLeft: Boolean, + expectedG: Option[Expression] = None, zeroType: DataType = IntegerType, argType: Option[DataType] = None): ArrayAggregateDecomposition = { val d = decompose(merge(body, acc, x), identityFinish(acc), argType.getOrElse(arrTy(zeroType)), zeroType) val r = d.getOrElse(fail(s"expected decomposition for body=$body, got Left: $d")) assert(r.op == expectedOp) - assert(r.gIsLeftOfMergeBody == expectedGIsLeft, - s"expected gIsLeftOfMergeBody=$expectedGIsLeft, got ${r.gIsLeftOfMergeBody}") + expectedG.foreach { g => + assert(r.g.fastEquals(g), s"expected g=$g, got ${r.g}") + } assert(r.accVarExprId == acc.exprId) assert(r.elemVar.exprId == x.exprId) r @@ -75,56 +76,56 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { test("Add(acc, x) -> SUM, g on the right") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Add(acc, x), acc, x, SumOp, expectedGIsLeft = false) + assertDecomposes(Add(acc, x), acc, x, SumOp) } test("Add(x, acc) (commuted) -> SUM, g on the left") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Add(x, acc), acc, x, SumOp, expectedGIsLeft = true) + assertDecomposes(Add(x, acc), acc, x, SumOp) } test("Multiply(acc, x) -> PRODUCT") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Multiply(acc, x), acc, x, ProductOp, expectedGIsLeft = false) + assertDecomposes(Multiply(acc, x), acc, x, ProductOp) } test("Greatest(acc, x) -> MAX") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Greatest(Seq(acc, x)), acc, x, MaxOp, expectedGIsLeft = false) + assertDecomposes(Greatest(Seq(acc, x)), acc, x, MaxOp) } test("Least(acc, x) -> MIN") { val acc = lv("acc"); val x = lv("x") - assertDecomposes(Least(Seq(acc, x)), acc, x, MinOp, expectedGIsLeft = false) + assertDecomposes(Least(Seq(acc, x)), acc, x, MinOp) } test("And(acc, x) -> ALL") { val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) - assertDecomposes(And(acc, x), acc, x, AllOp, expectedGIsLeft = false, + assertDecomposes(And(acc, x), acc, x, AllOp, zeroType = BooleanType) } test("Or(acc, x) -> ANY") { val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) - assertDecomposes(Or(acc, x), acc, x, AnyOp, expectedGIsLeft = false, + assertDecomposes(Or(acc, x), acc, x, AnyOp, zeroType = BooleanType) } test("Complex g(x) with no acc ref decomposes (g on the right)") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) val g = Cast(Add(Multiply(x, Literal(2)), Literal(1)), LongType) - assertDecomposes(Add(acc, g), acc, x, SumOp, expectedGIsLeft = false, zeroType = LongType) + assertDecomposes(Add(acc, g), acc, x, SumOp, zeroType = LongType) } test("Cast wrapping the acc side is unwrapped (single layer)") { val acc = lv("acc", LongType); val x = lv("x", IntegerType) - assertDecomposes(Add(Cast(acc, IntegerType), x), acc, x, SumOp, expectedGIsLeft = false) + assertDecomposes(Add(Cast(acc, IntegerType), x), acc, x, SumOp) } test("Cast wrapping the acc side is unwrapped (chained)") { val acc = lv("acc"); val x = lv("x") val doubleCastAcc = Cast(Cast(acc, LongType), IntegerType) - assertDecomposes(Add(doubleCastAcc, x), acc, x, SumOp, expectedGIsLeft = false) + assertDecomposes(Add(doubleCastAcc, x), acc, x, SumOp) } test("Subtract is not an associative op we recognize") { @@ -228,4 +229,86 @@ class ArrayAggregateDecomposerSuite extends GpuUnitTests { zeroType = LongType) assert(msg.contains("does not match"), s"expected type-mismatch error, got: $msg") } + + // If / CaseWhen normalize: branches that are op-of-acc (or bare acc treated as + // op(acc, identity)) get lifted out so cond-driven count-if patterns run on the GPU. + + test("If(cond, acc + t, acc) decomposes to SUM (revans's pattern)") { + val acc = lv("acc"); val x = lv("x") + val body = If(EqualTo(x, Literal(7)), Add(acc, Literal(1)), acc) + assertDecomposes(body, acc, x, SumOp) + } + + test("If(cond, acc, acc + t) decomposes to SUM (commuted branches)") { + val acc = lv("acc"); val x = lv("x") + val body = If(EqualTo(x, Literal(7)), acc, Add(acc, Literal(1))) + assertDecomposes(body, acc, x, SumOp) + } + + test("If(cond, acc + t1, acc + t2) — both branches op-of-acc — decomposes") { + val acc = lv("acc"); val x = lv("x") + val body = If(GreaterThan(x, Literal(0)), Add(acc, x), Add(acc, Literal(0))) + assertDecomposes(body, acc, x, SumOp) + } + + test("If with MAX (greatest(acc, x)) on one branch and bare acc on the other") { + val acc = lv("acc"); val x = lv("x") + val body = If(GreaterThan(x, Literal(0)), Greatest(Seq(acc, x)), acc) + assertDecomposes(body, acc, x, MaxOp) + } + + test("If with And on boolean acc decomposes to ALL") { + val acc = lv("acc", BooleanType); val x = lv("x", BooleanType) + val body = If(EqualTo(x, Literal(true)), And(acc, x), acc) + assertDecomposes(body, acc, x, AllOp, + zeroType = BooleanType) + } + + test("CaseWhen with multiple acc+t branches and acc else decomposes") { + val acc = lv("acc"); val x = lv("x") + val body = CaseWhen( + Seq( + (EqualTo(x, Literal(1)), Add(acc, Literal(10))), + (EqualTo(x, Literal(2)), Add(acc, Literal(20)))), + Some(acc)) + assertDecomposes(body, acc, x, SumOp) + } + + test("If condition references acc — rejected (g must not depend on acc)") { + val acc = lv("acc"); val x = lv("x") + val body = If(GreaterThan(acc, Literal(100)), Add(acc, Literal(1)), acc) + assertRejects(merge(body, acc, x), identityFinish(acc), + "cond referencing acc breaks per-element parallelism") + } + + test("If branches use different ops — rejected") { + val acc = lv("acc"); val x = lv("x") + val body = If(GreaterThan(x, Literal(0)), Add(acc, Literal(1)), Multiply(acc, Literal(2))) + assertRejects(merge(body, acc, x), identityFinish(acc), + "branches mixing Add and Multiply have no single op to lift") + } + + test("If branches put acc on different sides — rejected") { + val acc = lv("acc"); val x = lv("x") + val body = If(GreaterThan(x, Literal(0)), Add(acc, x), Add(x, acc)) + assertRejects(merge(body, acc, x), identityFinish(acc), + "branches with acc on different sides can't share a single lifted form") + } + + test("CaseWhen without else — rejected") { + val acc = lv("acc"); val x = lv("x") + val body = CaseWhen( + Seq((EqualTo(x, Literal(1)), Add(acc, Literal(10)))), + None) + assertRejects(merge(body, acc, x), identityFinish(acc), + "CaseWhen with no else has implicit null fallthrough we don't model") + } + + test("Nested If is decomposed recursively") { + val acc = lv("acc"); val x = lv("x") + // if(c1, if(c2, acc + 1, acc + 2), acc) + val inner = If(GreaterThan(x, Literal(10)), Add(acc, Literal(1)), Add(acc, Literal(2))) + val outer = If(GreaterThan(x, Literal(0)), inner, acc) + assertDecomposes(outer, acc, x, SumOp) + } } From edcc0aa77234f43e35d23847feefe6e1bee8f5e2 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 5 May 2026 10:21:34 +0800 Subject: [PATCH 12/12] Gate float/double SUM/PRODUCT in ArrayAggregate behind variableFloatAgg.enabled cuDF's parallel tree-reduction sums in a different order than Spark's sequential left-fold, so GPU vs CPU can differ in the low bits on Float/Double. Reuse the same conf gate as scalar GpuSum/GpuAverage (spark.rapids.sql.variableFloatAgg.enabled) via GpuOverrides.checkAndTagFloatAgg in GpuArrayAggregateMeta. Default true matches the global policy. Added integration test for the conf=false fallback path. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../python/higher_order_functions_test.py | 22 +++++++++++++++++++ .../spark/rapids/higherOrderFunctions.scala | 17 ++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/integration_tests/src/main/python/higher_order_functions_test.py b/integration_tests/src/main/python/higher_order_functions_test.py index 5c411566306..1d06e64b3cd 100644 --- a/integration_tests/src/main/python/higher_order_functions_test.py +++ b/integration_tests/src/main/python/higher_order_functions_test.py @@ -264,3 +264,25 @@ def test_array_aggregate_double_extremum_falls_back(lambda_sql, init_sql): lambda spark: unary_op_df(spark, ArrayGen(double_gen, max_length=5)).selectExpr( f'aggregate(a, {init_sql}, {lambda_sql}) as res'), 'ArrayAggregate') + + +# SUM / PRODUCT on FLOAT and DOUBLE: cuDF's parallel tree-reduction sums in a different +# order than Spark's sequential left-fold, so GPU vs CPU can differ in the low bits. Gated +# by `spark.rapids.sql.variableFloatAgg.enabled` (default true, same as scalar GpuSum) — +# we only verify the fallback path here, since the GPU path under default conf accepts +# minor numeric divergence and cannot use strict-equality assertions. +@pytest.mark.parametrize('elem_gen, lambda_sql, init_sql', [ + (float_gen, '(acc, x) -> acc + x', 'CAST(0 as FLOAT)'), + (double_gen, '(acc, x) -> acc + x', 'CAST(0 as DOUBLE)'), + (float_gen, '(acc, x) -> acc * x', 'CAST(1 as FLOAT)'), + (double_gen, '(acc, x) -> acc * x', 'CAST(1 as DOUBLE)'), +], ids=['float-sum', 'double-sum', 'float-product', 'double-product']) +@disable_ansi_mode +@allow_non_gpu('ProjectExec') +def test_array_aggregate_float_sum_product_falls_back_when_variable_float_agg_disabled( + elem_gen, lambda_sql, init_sql): + assert_gpu_fallback_collect( + lambda spark: unary_op_df(spark, ArrayGen(elem_gen, max_length=5)).selectExpr( + f'aggregate(a, {init_sql}, {lambda_sql}) as res'), + 'ArrayAggregate', + conf={'spark.rapids.sql.variableFloatAgg.enabled': 'false'}) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 00810155113..7daea466da6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -975,6 +975,10 @@ case object SumOp extends AggOp { case a: Add => Some((a.left, a.right)) case _ => None } + // Float/Double are gated behind `spark.rapids.sql.variableFloatAgg.enabled` (same conf + // as scalar GpuSum/GpuAverage) — cuDF's parallel tree-reduction sums in a different + // order than Spark's sequential left-fold, so the low-bit answer can differ even though + // both are valid IEEE 754 results. The check happens in GpuArrayAggregateMeta. def supportsType(t: DataType): Boolean = t.isInstanceOf[NumericType] } @@ -1007,7 +1011,9 @@ case object ProductOp extends AggOp { case m: Multiply => Some((m.left, m.right)) case _ => None } - // Decimal would need DecimalUtils.multiplyDecimals for overflow handling — exclude for now. + // Float/Double gated by variableFloatAgg.enabled (see SumOp). Decimal would also need + // DecimalUtils.multiplyDecimals for overflow handling — out of scope, so PRODUCT + // excludes Decimal entirely. def supportsType(t: DataType): Boolean = t match { case _: NumericType => !t.isInstanceOf[DecimalType] case _ => false @@ -1486,7 +1492,14 @@ class GpuArrayAggregateMeta( ArrayAggregateDecomposer.decompose( expr.merge, expr.finish, expr.argument.dataType, expr.zero.dataType) match { case Left(reason) => willNotWorkOnGpu(reason) - case Right(d) => decomposition = Some(d) + case Right(d) => + // SUM/PRODUCT on Float/Double diverge between cuDF's parallel tree-reduction + // and Spark's sequential left-fold. Same conf gate as GpuSum / GpuAverage — + // willNotWorkOnGpu when variableFloatAgg.enabled=false. + if (d.op == SumOp || d.op == ProductOp) { + GpuOverrides.checkAndTagFloatAgg(expr.zero.dataType, this.conf, this) + } + decomposition = Some(d) } }