Skip to content

Commit 0885dcb

Browse files
committed
initial commit
1 parent 5222f1a commit 0885dcb

File tree

2 files changed

+98
-51
lines changed

2 files changed

+98
-51
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.SQLConfHelper
21+
import org.apache.spark.sql.catalyst.expressions.{Expression, LambdaFunction, NamedLambdaVariable}
22+
import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLConf, toSQLId}
23+
import org.apache.spark.sql.internal.SQLConf
24+
import org.apache.spark.sql.types.DataType
25+
26+
/**
27+
* Binds lambda function arguments to their types and validates lambda argument constraints.
28+
*
29+
* This object creates a bound [[LambdaFunction]] by binding the arguments to the given type
30+
* information (dataType and nullability). The argument names come from the lambda function
31+
* itself. It handles three cases:
32+
*
33+
* 1. Already bound lambda functions: Returns the function as-is, assuming it has been
34+
* correctly bound to its arguments.
35+
*
36+
* 2. Unbound lambda functions: Validates and binds the function by:
37+
* - Checking that the number of arguments matches the expected count
38+
* - Checking for duplicate argument names (respecting case sensitivity configuration)
39+
* - Creating [[NamedLambdaVariable]] instances with the provided types
40+
*
41+
* 3. Non-lambda expressions: Wraps the expression in a lambda function with hidden arguments
42+
* (named `col0`, `col1`, etc.). This is used when an expression does not consume lambda
43+
* arguments but needs to be passed to a higher-order function. The arguments are hidden to
44+
* prevent accidental naming collisions.
45+
*/
46+
object LambdaBinder extends SQLConfHelper {
47+
48+
private def canonicalizer: String => String = {
49+
if (!conf.caseSensitiveAnalysis) {
50+
// scalastyle:off caselocale
51+
s: String =>
52+
s.toLowerCase
53+
// scalastyle:on caselocale
54+
} else { s: String =>
55+
s
56+
}
57+
}
58+
59+
def apply(expression: Expression, argumentsInfo: Seq[(DataType, Boolean)]): LambdaFunction =
60+
expression match {
61+
case f: LambdaFunction if f.bound => f
62+
63+
case LambdaFunction(function, names, _) =>
64+
if (names.size != argumentsInfo.size) {
65+
expression.failAnalysis(
66+
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH",
67+
messageParameters = Map(
68+
"expectedNumArgs" -> names.size.toString,
69+
"actualNumArgs" -> argumentsInfo.size.toString
70+
)
71+
)
72+
}
73+
74+
if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) {
75+
expression.failAnalysis(
76+
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES",
77+
messageParameters = Map(
78+
"args" -> names.map(a => canonicalizer(a.name)).map(toSQLId(_)).mkString(", "),
79+
"caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key)
80+
)
81+
)
82+
}
83+
84+
val arguments = argumentsInfo.zip(names).map {
85+
case ((dataType, nullable), ne) =>
86+
NamedLambdaVariable(ne.name, dataType, nullable)
87+
}
88+
LambdaFunction(function, arguments)
89+
90+
case _ =>
91+
val arguments = argumentsInfo.zipWithIndex.map {
92+
case ((dataType, nullable), i) =>
93+
NamedLambdaVariable(s"col$i", dataType, nullable)
94+
}
95+
LambdaFunction(expression, arguments, hidden = true)
96+
}
97+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
2121
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2222
import org.apache.spark.sql.catalyst.rules.Rule
2323
import org.apache.spark.sql.catalyst.trees.TreePattern._
24-
import org.apache.spark.sql.catalyst.util.TypeUtils.{toSQLConf, toSQLId}
25-
import org.apache.spark.sql.internal.SQLConf
26-
import org.apache.spark.sql.types.DataType
2724

2825
/**
2926
* Resolve the lambda variables exposed by a higher order functions.
@@ -49,53 +46,6 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] {
4946
}
5047
}
5148

52-
/**
53-
* Create a bound lambda function by binding the arguments of a lambda function to the given
54-
* partial arguments (dataType and nullability only). If the expression happens to be an already
55-
* bound lambda function then we assume it has been bound to the correct arguments and do
56-
* nothing. This function will produce a lambda function with hidden arguments when it is passed
57-
* an arbitrary expression.
58-
*/
59-
private def createLambda(
60-
e: Expression,
61-
argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
62-
case f: LambdaFunction if f.bound => f
63-
64-
case LambdaFunction(function, names, _) =>
65-
if (names.size != argInfo.size) {
66-
e.failAnalysis(
67-
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NUM_ARGS_MISMATCH",
68-
messageParameters = Map(
69-
"expectedNumArgs" -> names.size.toString,
70-
"actualNumArgs" -> argInfo.size.toString))
71-
}
72-
73-
if (names.map(a => conf.canonicalize(a.name)).distinct.size < names.size) {
74-
e.failAnalysis(
75-
errorClass = "INVALID_LAMBDA_FUNCTION_CALL.DUPLICATE_ARG_NAMES",
76-
messageParameters = Map(
77-
"args" -> names.map(a => conf.canonicalize(a.name)).map(toSQLId(_)).mkString(", "),
78-
"caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key)))
79-
}
80-
81-
val arguments = argInfo.zip(names).map {
82-
case ((dataType, nullable), ne) =>
83-
NamedLambdaVariable(ne.name, dataType, nullable)
84-
}
85-
LambdaFunction(function, arguments)
86-
87-
case _ =>
88-
// This expression does not consume any of the lambda's arguments (it is independent). We do
89-
// create a lambda function with default parameters because this is expected by the higher
90-
// order function. Note that we hide the lambda variables produced by this function in order
91-
// to prevent accidental naming collisions.
92-
val arguments = argInfo.zipWithIndex.map {
93-
case ((dataType, nullable), i) =>
94-
NamedLambdaVariable(s"col$i", dataType, nullable)
95-
}
96-
LambdaFunction(e, arguments, hidden = true)
97-
}
98-
9949
/**
10050
* Resolve lambda variables in the expression subtree, using the passed lambda variable registry.
10151
*/
@@ -104,7 +54,7 @@ object ResolveLambdaVariables extends Rule[LogicalPlan] {
10454

10555
case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess =>
10656
SubqueryExpressionInLambdaOrHigherOrderFunctionValidator(e)
107-
h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))
57+
h.bind(LambdaBinder(_, _)).mapChildren(resolve(_, parentLambdaMap))
10858

10959
case l: LambdaFunction if !l.bound =>
11060
SubqueryExpressionInLambdaOrHigherOrderFunctionValidator(e)

0 commit comments

Comments
 (0)