Skip to content

Commit 48b8c60

initial commit
1 parent 5222f1a commit 48b8c60

File tree

2 files changed: +91 −51 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LambdaBinder.scala

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.expressions.{Expression, LambdaFunction, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLConf, toSQLId}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Object used to bind lambda function arguments to their types and validate lambda argument
+ * constraints.
+ *
+ * This object creates a bound [[LambdaFunction]] by binding the arguments to the given type
+ * information (dataType and nullability). The argument names come from the lambda function
+ * itself. It handles three cases:
+ *
+ * 1. Already bound lambda functions: Returns the function as-is, assuming it has been
+ *    correctly bound to its arguments.
+ *
+ * 2. Unbound lambda functions: Validates and binds the function by:
+ *    - Checking that the number of arguments matches the expected count
+ *    - Checking for duplicate argument names (respecting case sensitivity configuration)
+ *    - Creating [[NamedLambdaVariable]] instances with the provided types
+ *
+ * 3. Non-lambda expressions: Wraps the expression in a lambda function with hidden arguments
+ *    (named `col0`, `col1`, etc.). This is used when an expression does not consume lambda
+ *    arguments but needs to be passed to a higher-order function. The arguments are hidden to
+ *    prevent accidental naming collisions.
+ */
+object LambdaBinder extends SQLConfHelper {
+
+  /**
+   * Binds lambda function arguments to their types and validates lambda argument constraints.
+   */
+  def apply(expression: Expression, argumentsInfo: Seq[(DataType, Boolean)]): LambdaFunction =
+    expression match {
+      case f: LambdaFunction if f.bound => f
+
+      case LambdaFunction(function, names, _) =>
+        if (names.size != argumentsInfo.size) {
+          expression.failAnalysis(
+            errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH",
+            messageParameters = Map(
+              "expectedNumArgs" -> names.size.toString,
+              "actualNumArgs" -> argumentsInfo.size.toString
+            )
+          )
+        }
+
+        if (names.map(a => conf.canonicalize(a.name)).distinct.size < names.size) {
+          expression.failAnalysis(
+            errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES",
+            messageParameters = Map(
+              "args" -> names.map(a => conf.canonicalize(a.name)).map(toSQLId(_)).mkString(", "),
+              "caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key)
+            )
+          )
+        }
+
+        val arguments = argumentsInfo.zip(names).map {
+          case ((dataType, nullable), ne) =>
+            NamedLambdaVariable(ne.name, dataType, nullable)
+        }
+        LambdaFunction(function, arguments)
+
+      case _ =>
+        val arguments = argumentsInfo.zipWithIndex.map {
+          case ((dataType, nullable), i) =>
+            NamedLambdaVariable(s"col$i", dataType, nullable)
+        }
+        LambdaFunction(expression, arguments, hidden = true)
+    }
+}
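
For context, here is a rough sketch of how the three binding cases surface at the SQL level. This is illustration only, not part of the commit: it assumes a local SparkSession and the default spark.sql.caseSensitive=false, and the LambdaBinderDemo object name is made up.

import org.apache.spark.sql.SparkSession

object LambdaBinderDemo extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  // Case 2: the unbound lambda `x -> x + 1` is bound to the array's element
  // type, so `x` becomes a NamedLambdaVariable with IntegerType.
  spark.sql("SELECT transform(array(1, 2, 3), x -> x + 1)").show()

  // NUM_ARGS_MISMATCH: transform binds at most two lambda arguments
  // (element and index), so a three-argument lambda should fail analysis.
  // spark.sql("SELECT transform(array(1, 2, 3), (x, y, z) -> x)").show()

  // DUPLICATE_ARG_NAMES: with spark.sql.caseSensitive=false, `x` and `X`
  // canonicalize to the same name, so the duplicate-name check should fire.
  // spark.sql("SELECT transform(array(1, 2, 3), (x, X) -> x)").show()

  spark.stop()
}

Case 3 (a non-lambda expression) is not written directly in SQL; it arises internally when a plain expression is handed to a higher-order function, and the binder wraps it in a lambda with hidden col0, col1, ... arguments to avoid name collisions.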

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala

Lines changed: 1 addition & 51 deletions
@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLConf, toSQLId}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
 
 /**
  * Resolve the lambda variables exposed by higher order functions.
@@ -49,53 +46,6 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] {
     }
   }
 
-  /**
-   * Create a bound lambda function by binding the arguments of a lambda function to the given
-   * partial arguments (dataType and nullability only). If the expression happens to be an already
-   * bound lambda function then we assume it has been bound to the correct arguments and do
-   * nothing. This function will produce a lambda function with hidden arguments when it is passed
-   * an arbitrary expression.
-   */
-  private def createLambda(
-      e: Expression,
-      argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
-    case f: LambdaFunction if f.bound => f
-
-    case LambdaFunction(function, names, _) =>
-      if (names.size != argInfo.size) {
-        e.failAnalysis(
-          errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH",
-          messageParameters = Map(
-            "expectedNumArgs" -> names.size.toString,
-            "actualNumArgs" -> argInfo.size.toString))
-      }
-
-      if (names.map(a => conf.canonicalize(a.name)).distinct.size < names.size) {
-        e.failAnalysis(
-          errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES",
-          messageParameters = Map(
-            "args" -> names.map(a => conf.canonicalize(a.name)).map(toSQLId(_)).mkString(", "),
-            "caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key)))
-      }
-
-      val arguments = argInfo.zip(names).map {
-        case ((dataType, nullable), ne) =>
-          NamedLambdaVariable(ne.name, dataType, nullable)
-      }
-      LambdaFunction(function, arguments)
-
-    case _ =>
-      // This expression does not consume any of the lambda's arguments (it is independent). We do
-      // create a lambda function with default parameters because this is expected by the higher
-      // order function. Note that we hide the lambda variables produced by this function in order
-      // to prevent accidental naming collisions.
-      val arguments = argInfo.zipWithIndex.map {
-        case ((dataType, nullable), i) =>
-          NamedLambdaVariable(s"col$i", dataType, nullable)
-      }
-      LambdaFunction(e, arguments, hidden = true)
-  }
-
   /**
    * Resolve lambda variables in the expression subtree, using the passed lambda variable registry.
    */
@@ -104,7 +54,7 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] {
 
     case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess =>
       SubqueryExpressionInLambdaOrHigherOrderFunctionValidator(e)
-      h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))
+      h.bind(LambdaBinder(_, _)).mapChildren(resolve(_, parentLambdaMap))
 
     case l: LambdaFunction if !l.bound =>
       SubqueryExpressionInLambdaOrHigherOrderFunctionValidator(e)
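
The only call-site change is that the binder callback passed to HigherOrderFunction.bind is now the new object: LambdaBinder(_, _) eta-expands its apply method into the same function type the removed private createLambda had. A minimal sketch of that equivalence, with the types taken from the apply signature above (the BinderShapeDemo wrapper is just for illustration):

import org.apache.spark.sql.catalyst.analysis.LambdaBinder
import org.apache.spark.sql.catalyst.expressions.{Expression, LambdaFunction}
import org.apache.spark.sql.types.DataType

object BinderShapeDemo {
  // The callback shape h.bind expects, inferred from the removed createLambda
  // signature: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction.
  val binder: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction =
    LambdaBinder(_, _)

  // Equivalent spelling via plain eta-expansion of the object's apply method.
  val binderEta: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction =
    LambdaBinder.apply
}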
