From b694e2fb48b206026d71da878da1f75e1670ebf5 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Wed, 22 Oct 2025 21:23:29 -0700 Subject: [PATCH 01/15] Loosen flattening rules for sort and filter --- .../_internal/analyzer/select_statement.py | 52 +++++++++----- tests/integ/test_simplifier_suite.py | 71 ++++++++++++++----- 2 files changed, 89 insertions(+), 34 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 723277a31d..b6d92411d7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -20,6 +20,7 @@ Sequence, Set, Union, + Literal, ) import snowflake.snowpark._internal.utils @@ -1362,9 +1363,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ): # TODO: Clean up, this entire if case is parameter protection can_be_flattened = False - elif (self.where or self.order_by or self.limit_) and has_data_generator_exp( - cols - ): + elif ( + self.where or self.order_by or self.limit_ + ) and has_data_generator_or_window_function_exp(cols): can_be_flattened = False elif self.where and ( (subquery_dependent_columns := derive_dependent_columns(self.where)) @@ -1453,9 +1454,9 @@ def filter(self, col: Expression) -> "SelectStatement": can_be_flattened = ( (not self.flatten_disabled) and can_clause_dependent_columns_flatten( - derive_dependent_columns(col), self.column_states + derive_dependent_columns(col), self.column_states, "filter" ) - and not has_data_generator_exp(self.projection) + and not has_data_generator_or_window_function_exp(self.projection) and not (self.order_by and self.limit_ is not None) ) if can_be_flattened: @@ -1490,7 +1491,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement": and (not self.limit_) and (not self.offset) and can_clause_dependent_columns_flatten( - derive_dependent_columns(*cols), self.column_states + 
derive_dependent_columns(*cols), self.column_states, "sort" ) and not has_data_generator_exp(self.projection) ) @@ -1529,7 +1530,7 @@ def distinct(self) -> "SelectStatement": # .order_by(col1).select(col2).distinct() cannot be flattened because # SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL and (not (self.order_by and self.has_projection)) - and not has_data_generator_exp(self.projection) + and not has_data_generator_or_window_function_exp(self.projection) ) if can_be_flattened: new = copy(self) @@ -2020,7 +2021,12 @@ def can_projection_dependent_columns_be_flattened( def can_clause_dependent_columns_flatten( dependent_columns: Optional[AbstractSet[str]], subquery_column_states: ColumnStateDict, + clause: Literal["filter", "sort"], ) -> bool: + if clause not in ["filter", "sort"]: + raise ValueError( + f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" + ) if dependent_columns == COLUMN_DEPENDENCY_DOLLAR: return False elif ( @@ -2034,15 +2040,10 @@ def can_clause_dependent_columns_flatten( for dc in dependent_columns: dc_state = subquery_column_states.get(dc) if dc_state: - if dc_state.change_state == ColumnChangeState.CHANGED_EXP: - return False - elif dc_state.change_state == ColumnChangeState.NEW: - # Most of the time this can be flattened. But if a new column uses window function and this column - # is used in a clause, the sql doesn't work in Snowflake. - # For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work. - # But `select a, b as d from test_table where d = 1` works - # We can inspect whether the referenced new column uses window function. Here we are being - # conservative for now to not flatten the SQL. 
+ if ( + dc_state.change_state == ColumnChangeState.CHANGED_EXP + and clause == "filter" + ): return False return True @@ -2264,8 +2265,6 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: if expressions is None: return False for exp in expressions: - if isinstance(exp, WindowExpression): - return True if isinstance(exp, FunctionExpression) and ( exp.is_data_generator or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION @@ -2275,3 +2274,20 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: if exp is not None and has_data_generator_exp(exp.children): return True return False + + +def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool: + if expressions is None: + return False + for exp in expressions: + if isinstance(exp, WindowExpression): + return True + if exp is not None and has_window_function_exp(exp.children): + return True + return False + + +def has_data_generator_or_window_function_exp( + expressions: Optional[List["Expression"]], +) -> bool: + return has_data_generator_exp(expressions) or has_window_function_exp(expressions) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index f42cb176cd..55942d31ed 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -9,7 +9,7 @@ import pytest -from snowflake.snowpark import Row +from snowflake.snowpark import Row, Window from snowflake.snowpark._internal.analyzer.select_statement import ( SET_EXCEPT, SET_INTERSECT, @@ -30,6 +30,7 @@ sum as sum_, table_function, udtf, + rank, ) from tests.utils import TestData, Utils @@ -754,21 +755,35 @@ def test_order_by(setup_reduce_cast, session, simplifier_table): f'SELECT "A", "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) - # no flatten because c is a new column + # flatten if a new column is used in the order by clause df3 = df.select("a", "b", (col("a") - 
col("b")).as_("c")).sort("a", "b", "c") assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' + f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' ) - # no flatten because a and be are changed + # still flatten even if a is changed because it's used in the order by clause df4 = df.select((col("a") + 1).as_("a"), ((col("b") + 1).as_("b"))).sort("a", "b") assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + f'SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) - # subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten. 
- df5 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b") + # still flatten if a window function is used in the projection + df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).sort( + "a", "b" + ) assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + + # No flatten if a data generator is used in the projection + df6 = df.select("a", "b", seq1().alias("c")).sort("a", "b") + assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table}) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + + # subquery has sql text so unable to figure out if a data generator is used in the projection. No flatten. + df7 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b") + assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) @@ -791,32 +806,56 @@ def test_filter(setup_reduce_cast, session, simplifier_table): f'SELECT "A", "B" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) - # no flatten because c is a new column + # flatten if a regular new column is in the projection df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( - (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + (col("a") > 1) & (col("b") > 2) ) assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + 
f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' + ) + + # flatten if a regular new column is used in the filter clause + df4 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) + + # no flatten if a window function is used in the projection + df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) + + # no flatten if a data generator is used in the projection + df6 = df.select("a", "b", seq1().alias("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' ) # no flatten because a and be are changed - df4 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter( + df7 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter( (col("a") > 1) & (col("b") > 2) ) - assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM 
( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) - df5 = df4.select("a") - assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + df8 = df7.select("a") + assert Utils.normalize_sql(df8.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT "A" FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) # subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten. - df6 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter( + df9 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter( col("a") > 1 ) - assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + assert Utils.normalize_sql(df9.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) WHERE ("A" > 1{integer_literal_postfix})' ) From a25942b13dac2303da4d423e10d563dc9f38a9c4 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 27 Oct 2025 14:02:16 -0700 Subject: [PATCH 02/15] No flatten if dropped columns after order by / sort --- .../_internal/analyzer/select_statement.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index b6d92411d7..40467d8fba 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1376,7 +1376,18 @@ def select(self, cols: List[Expression]) -> "SelectStatement": subquery_dependent_columns & new_column_states.active_columns ) ) 
+ or ( + new_column_states.dropped_columns + and any( + new_column_states[_col].change_state == ColumnChangeState.DROPPED + and self.column_states[_col].change_state + in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) + and _col in subquery_dependent_columns + for _col in (new_column_states.dropped_columns) + ) + ) ): + # or (new_column_states[_col].change_state == ColumnChangeState.DROPPED and self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)) can_be_flattened = False elif self.order_by and ( (subquery_dependent_columns := derive_dependent_columns(*self.order_by)) @@ -1388,6 +1399,16 @@ def select(self, cols: List[Expression]) -> "SelectStatement": subquery_dependent_columns & new_column_states.active_columns ) ) + or ( + new_column_states.dropped_columns + and any( + new_column_states[_col].change_state == ColumnChangeState.DROPPED + and self.column_states[_col].change_state + in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) + and _col in subquery_dependent_columns + for _col in (new_column_states.dropped_columns) + ) + ) ): can_be_flattened = False elif self.distinct_: From 1df1e96fa7c30979e5feacc6c91d7b3f176289c4 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 27 Oct 2025 23:46:30 -0700 Subject: [PATCH 03/15] Add tests in simplifier --- tests/integ/test_simplifier_suite.py | 79 ++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 55942d31ed..bc9d08b068 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -1650,6 +1650,22 @@ def test_chained_sort(session): .filter(col("A") > 2), 'SELECT "A", "B", 12 :: INT AS "TWELVE" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE (("A" > 1{POSTFIX}) AND ("A" > 2{POSTFIX}))', ), + # Flattened if the dropped columns are not used in filter + ( + lambda df: 
df.filter(col("A") >= 1) + .select(col("A").alias("C"), col("B").alias("D")) + .filter(col("C") > 2) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("C" > 2{POSTFIX}))', + ), + # Flattened if the dropped columns are not in the filter clause's dependent columns + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A").alias("C"), col("B").alias("D")) + .filter((col("C") + 1) > 2) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("C" + 1{POSTFIX}) > 2{POSTFIX}))', + ), # Not fully flattened, since col("A") > 1 and col("A") > 2 are referring to different columns ( lambda df: df.filter(col("A") > 1) @@ -1672,6 +1688,29 @@ def test_chained_sort(session): lambda df: df.filter(col("$1") > 1).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) WHERE ("$1" > 1{POSTFIX}) )', ), + # Not flattened if a dropped column is used in the filter clause + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A"), col("B").alias("D")) + .filter(col("D") > -3) + .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("D" > -3{POSTFIX})))', + ), + # Not flattened if a dropped column is used in the select clause's dependent columns + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A"), col("B").alias("D")) + .filter((col("D") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("D" - 
1{POSTFIX}) > -4{POSTFIX})))', + ), + # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns + ( + lambda df: df.select(col("A"), (col("B") + 1).alias("B")) + .filter((col("B") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + ), ], ) def test_select_after_filter(setup_reduce_cast, session, operation, simplified_query): @@ -1742,6 +1781,46 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q 'SELECT "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "B" ASC NULLS FIRST, "A" ASC NULLS FIRST', True, ), + # Flattened if the dropped columns are not used in filter + ( + lambda df: df.select(col("A").alias("C"), col("B").alias("D")) + .order_by(col("C")) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "C" ASC NULLS FIRST', + True, + ), + # Flattened if the dropped columns are not in the order by clause's dependent columns + ( + lambda df: df.select(col("A").alias("C"), col("B").alias("D")) + .order_by(col("C") + 1) + .select(col("C")), + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', + True, + ), + # Not flattened if a dropped new column is used in the order by clause + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D")) + .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST)', + True, + ), + # Not flattened if a dropped new column is used in 
the order by clause's dependent columns + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D") - 1) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST)', + True, + ), + # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns + ( + lambda df: df.select(col("A"), (col("B") + 1).alias("B")) + .filter((col("B") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + True, + ), ], ) def test_select_after_orderby( From 11043611e01874f80f8084a7edb0561853813ae9 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Tue, 28 Oct 2025 10:53:52 -0700 Subject: [PATCH 04/15] Update test after flattening --- tests/integ/test_query_line_intervals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_query_line_intervals.py b/tests/integ/test_query_line_intervals.py index 2852a7c9e3..d60f951bf6 100644 --- a/tests/integ/test_query_line_intervals.py +++ b/tests/integ/test_query_line_intervals.py @@ -73,7 +73,7 @@ def generate_test_data(session, sql_simplifier_enabled): lambda data: data["df1"].filter(data["df1"].value > 150), True, { - 8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', + 8: """SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM (SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, 'A' :: STRING, 100 :: INT), (2 :: INT, 'B' :: STRING, 200 :: INT)) WHERE ("VALUE" > 150)""", }, ), ( From d67f5df9039fc6997c1853f67667449be0a23ce4 Mon Sep 17 00:00:00 2001 From: Adam Ling 
Date: Wed, 29 Oct 2025 23:49:29 -0700 Subject: [PATCH 05/15] parameter protection and agg function check for fitler --- .../_internal/analyzer/select_statement.py | 79 ++++++++++++++----- src/snowflake/snowpark/context.py | 1 + src/snowflake/snowpark/session.py | 14 ++++ 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 40467d8fba..36bb6f125d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -87,6 +87,7 @@ is_sql_select_statement, ExprAliasUpdateDict, ) +import snowflake.snowpark.context as context # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable # Python 3.9 can use both @@ -1377,17 +1378,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ) ) or ( - new_column_states.dropped_columns + # unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery + # reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error + # sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' can not be flatten to 'select "b" from table where "c" > 1' + context._is_snowpark_connect_compatible_mode + and new_column_states.dropped_columns and any( - new_column_states[_col].change_state == ColumnChangeState.DROPPED - and self.column_states[_col].change_state + self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) - and _col in subquery_dependent_columns - for _col in (new_column_states.dropped_columns) + for _col in ( + subquery_dependent_columns & new_column_states.dropped_columns + ) ) ) ): - # or (new_column_states[_col].change_state == ColumnChangeState.DROPPED and 
self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)) can_be_flattened = False elif self.order_by and ( (subquery_dependent_columns := derive_dependent_columns(*self.order_by)) @@ -1400,13 +1404,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ) ) or ( - new_column_states.dropped_columns + # unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery + # reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error + # sample query: 'select "b" from (select "a" as "c", "b" order by "c")' can not be flatten to 'select "b" from table order by "c"' + context._is_snowpark_connect_compatible_mode + and new_column_states.dropped_columns and any( - new_column_states[_col].change_state == ColumnChangeState.DROPPED - and self.column_states[_col].change_state + self.column_states[_col].change_state in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP) - and _col in subquery_dependent_columns - for _col in (new_column_states.dropped_columns) + for _col in ( + subquery_dependent_columns & new_column_states.dropped_columns + ) ) ) ): @@ -1478,6 +1486,10 @@ def filter(self, col: Expression) -> "SelectStatement": derive_dependent_columns(col), self.column_states, "filter" ) and not has_data_generator_or_window_function_exp(self.projection) + and not ( + context._is_snowpark_connect_compatible_mode + and has_aggregation_function_exp(self.projection) + ) # sum(col) as new_col, new_col can not be flattened in where clause and not (self.order_by and self.limit_ is not None) ) if can_be_flattened: @@ -2044,10 +2056,10 @@ def can_clause_dependent_columns_flatten( subquery_column_states: ColumnStateDict, clause: Literal["filter", "sort"], ) -> bool: - if clause not in ["filter", "sort"]: - raise ValueError( - f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" - 
) + assert clause in ( + "filter", + "sort", + ), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" if dependent_columns == COLUMN_DEPENDENCY_DOLLAR: return False elif ( @@ -2061,11 +2073,19 @@ def can_clause_dependent_columns_flatten( for dc in dependent_columns: dc_state = subquery_column_states.get(dc) if dc_state: - if ( - dc_state.change_state == ColumnChangeState.CHANGED_EXP - and clause == "filter" - ): - return False + if dc_state.change_state == ColumnChangeState.CHANGED_EXP: + if ( + clause == "filter" + ): # where can not be flattened because 'where' is evaluated before projection, flattening leads to wrong result + # df.select((col('a') + 1).alias('a')).filter(col('a') > 5) -- this should be applied to the new 'a', flattening will use the old 'a' to evaluated + return False + else: # clause == 'sort' + # df.select((col('a') + 1).alias('a')).sort(col('a')) -- this is valid to flatten because 'order by' is evaluated after projection + # however, if the order by is a data generator, it should not be flattened because generator is evaluated dynamically according to the order. 
+ return context._is_snowpark_connect_compatible_mode + elif dc_state.change_state == ColumnChangeState.NEW: + return context._is_snowpark_connect_compatible_mode + return True @@ -2286,6 +2306,10 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: if expressions is None: return False for exp in expressions: + if not context._is_snowpark_connect_compatible_mode and isinstance( + exp, WindowExpression + ): + return True if isinstance(exp, FunctionExpression) and ( exp.is_data_generator or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION @@ -2311,4 +2335,19 @@ def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool: def has_data_generator_or_window_function_exp( expressions: Optional[List["Expression"]], ) -> bool: + if not context._is_snowpark_connect_compatible_mode: + return has_data_generator_exp(expressions) return has_data_generator_exp(expressions) or has_window_function_exp(expressions) + + +def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool: + if expressions is None: + return False + for exp in expressions: + if isinstance(exp, FunctionExpression) and ( + exp.name.lower() in context._aggregation_function_set + ): + return True + if exp is not None and has_aggregation_function_exp(exp.children): + return True + return False diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index 86e92b6aa4..cffd79fc52 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -31,6 +31,7 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False +_aggregation_function_set = set() # Following are internal-only global flags, used to enable development features. 
_enable_dataframe_trace_on_error = False diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1afe626720..1067fd4471 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -518,6 +518,20 @@ def create(self) -> "Session": _add_session(session) else: session = self._create_internal(self._options.get("connection")) + if context._is_snowpark_connect_compatible_mode: + for sql in [ + """select function_name from information_schema.functions where is_aggregate = 'YES'""", + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + ]: + try: + context._aggregation_function_set.update( + {r[0] for r in session.sql(sql).collect()} + ) + except BaseException as e: + _logger.debug( + "Unable to get aggregation functions from the database: %s", + e, + ) if self._app_name: if self._format_json: From 43c3316fa4f6189bbd8caee271a3bce6514d704f Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 31 Oct 2025 12:12:59 -0700 Subject: [PATCH 06/15] fix tests --- .../_internal/analyzer/select_statement.py | 4 + src/snowflake/snowpark/context.py | 5 +- src/snowflake/snowpark/session.py | 39 ++-- tests/integ/test_query_line_intervals.py | 24 ++- tests/integ/test_simplifier_suite.py | 176 +++++++++++++++--- 5 files changed, 207 insertions(+), 41 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 36bb6f125d..c0bf4ce207 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1480,6 +1480,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement": return new def filter(self, col: Expression) -> "SelectStatement": + self._session._retrieve_aggregation_function_list() can_be_flattened = ( (not self.flatten_disabled) and can_clause_dependent_columns_flatten( @@ -1527,6 +1528,9 @@ def sort(self, cols: 
List[Expression]) -> "SelectStatement": derive_dependent_columns(*cols), self.column_states, "sort" ) and not has_data_generator_exp(self.projection) + # we do not check aggregation function here like filter + # in the case when aggregation function is in the projection + # order by is evaluated after aggregation, row info are not taken in the calculation ) if can_be_flattened: new = copy(self) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index cffd79fc52..ed1d15c5f2 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -31,7 +31,10 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False -_aggregation_function_set = set() +_aggregation_function_set = ( + set() +) # lower cased names of aggregation functions, used in sql simplification +_aggregation_function_set_lock = threading.RLock() # Following are internal-only global flags, used to enable development features. 
_enable_dataframe_trace_on_error = False diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 1067fd4471..6fea3308ae 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -518,20 +518,6 @@ def create(self) -> "Session": _add_session(session) else: session = self._create_internal(self._options.get("connection")) - if context._is_snowpark_connect_compatible_mode: - for sql in [ - """select function_name from information_schema.functions where is_aggregate = 'YES'""", - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", - ]: - try: - context._aggregation_function_set.update( - {r[0] for r in session.sql(sql).collect()} - ) - except BaseException as e: - _logger.debug( - "Unable to get aggregation functions from the database: %s", - e, - ) if self._app_name: if self._format_json: @@ -4874,6 +4860,31 @@ def _execute_sproc_internal( # Note the collect is implicit within the stored procedure call, so should not emit_ast here. 
return df.collect(statement_params=statement_params, _emit_ast=False)[0][0] + def _retrieve_aggregation_function_list(self) -> None: + """Retrieve the list of aggregation functions which will later be used in sql simplifier.""" + if ( + not context._is_snowpark_connect_compatible_mode + or context._aggregation_function_set + ): + return + + retrieved_set = set() + + for sql in [ + """select function_name from information_schema.functions where is_aggregate = 'YES'""", + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", + ]: + try: + retrieved_set.update({r[0].lower() for r in self.sql(sql).collect()}) + except BaseException as e: + _logger.debug( + "Unable to get aggregation functions from the database: %s", + e, + ) + + with context._aggregation_function_set_lock: + context._aggregation_function_set.update(retrieved_set) + def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. 
diff --git a/tests/integ/test_query_line_intervals.py b/tests/integ/test_query_line_intervals.py index d60f951bf6..a95fec4855 100644 --- a/tests/integ/test_query_line_intervals.py +++ b/tests/integ/test_query_line_intervals.py @@ -57,8 +57,9 @@ def generate_test_data(session, sql_simplifier_enabled): } +@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False]) @pytest.mark.parametrize( - "op,sql_simplifier,line_to_expected_sql", + "op,sql_simplifier,line_to_expected_sql,snowpark_connect_compatible_mode_sql", [ ( lambda data: data["df1"].union(data["df2"]), @@ -68,10 +69,14 @@ def generate_test_data(session, sql_simplifier_enabled): 6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', 10: 'SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (3 :: INT, \'C\' :: STRING, 300 :: INT), (4 :: INT, \'D\' :: STRING, 400 :: INT) )', }, + None, ), ( lambda data: data["df1"].filter(data["df1"].value > 150), True, + { + 8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)' + }, { 8: """SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM (SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, 'A' :: STRING, 100 :: INT), (2 :: INT, 'B' :: STRING, 200 :: INT)) WHERE ("VALUE" > 150)""", }, @@ -83,6 +88,7 @@ def generate_test_data(session, sql_simplifier_enabled): 1: 'SELECT "_1" AS "ID", "_2" AS "NAME" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) )', 4: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', }, + None, ), ( lambda data: data["df1"].pivot(F.col("name")).sum(F.col("value")), @@ -92,12 +98,26 @@ def generate_test_data(session, 
sql_simplifier_enabled): 6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', 9: 'SELECT * FROM ( SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) ) ) PIVOT ( sum("VALUE") FOR "NAME" IN ( ANY ) )', }, + None, ), ], ) def test_get_plan_from_line_numbers_sql_content( - session, op, sql_simplifier, line_to_expected_sql + session, + op, + sql_simplifier, + line_to_expected_sql, + snowpark_connect_compatible_mode_sql, + snowpark_connect_compatible_mode, + monkeypatch, ): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + line_to_expected_sql = ( + snowpark_connect_compatible_mode_sql or line_to_expected_sql + ) session.sql_simplifier_enabled = sql_simplifier df = op(generate_test_data(session, sql_simplifier)) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index bc9d08b068..044e16f4f8 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -737,7 +737,19 @@ def test_reference_non_exist_columns(session, simplifier_table): df.select(col("c") + 1).collect() -def test_order_by(setup_reduce_cast, session, simplifier_table): +@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False]) +def test_order_by( + setup_reduce_cast, + session, + simplifier_table, + is_snowpark_connect_compatible_mode, + monkeypatch, +): + if is_snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + df = session.table(simplifier_table) # flatten @@ -755,24 +767,42 @@ def test_order_by(setup_reduce_cast, session, simplifier_table): f'SELECT "A", "B" FROM 
{simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) - # flatten if a new column is used in the order by clause + # snowpark connect compatible mode: flatten if a new column is used in the order by clause + # snowflake mode: no flatten because c is a new column df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).sort("a", "b", "c") - assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) - # still flatten even if a is changed because it's used in the order by clause + # snowpark connect compatible mode: flatten even if a is changed because it's used in the order by clause + # snowflake mode: no flatten because a and b are changed df4 = df.select((col("a") + 1).as_("a"), ((col("b") + 1).as_("b"))).sort("a", "b") - assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) - # still flatten if a window function is used in the projection + # snowpark connect compatible mode: flatten if a window function is used in the projection + # snowflake
mode: no flatten because c is a new column df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).sort( "a", "b" ) - assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) # No flatten if a data generator is used in the projection @@ -787,8 +817,30 @@ def test_order_by(setup_reduce_cast, session, simplifier_table): f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) + df8 = df.select("a", sum_(col("b")).alias("c")).sort("c") + compare_sql = ( + f'SELECT "A", sum("B") AS "C" FROM {simplifier_table} ORDER BY "C" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", sum("B") AS "C" FROM {simplifier_table} ) ORDER BY "C" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df8.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql + ) + + +@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False]) +def test_filter( + setup_reduce_cast, + session, + simplifier_table, + is_snowpark_connect_compatible_mode, + monkeypatch, +): + if is_snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) -def test_filter(setup_reduce_cast, session, simplifier_table): df = session.table(simplifier_table) integer_literal_postfix = ( "" if session.eliminate_numeric_sql_value_cast_enabled else " :: INT" @@ -808,18 
+860,15 @@ def test_filter(setup_reduce_cast, session, simplifier_table): # flatten if a regular new column is in the projection df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( - (col("a") > 1) & (col("b") > 2) - ) - assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' - ) - - # flatten if a regular new column is used in the filter clause - df4 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( (col("a") > 1) & (col("b") > 2) & (col("c") < 1) ) - assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql = ( f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) + assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql ) # no flatten if a window function is used in the projection @@ -859,6 +908,12 @@ def test_filter(setup_reduce_cast, session, simplifier_table): f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) WHERE ("A" > 1{integer_literal_postfix})' ) + # no flatten if an aggregation function is used in the projection + df10 = df.select("a", sum_(col("b")).alias("c")).filter(col("c") < 1) + assert Utils.normalize_sql(df10.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", sum("B") AS "C" FROM {simplifier_table} ) WHERE ("C" < 1{integer_literal_postfix})' + ) + def test_limit(setup_reduce_cast, session, simplifier_table): df =
session.table(simplifier_table) @@ -1630,18 +1685,21 @@ def test_chained_sort(session): ) +@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False]) @pytest.mark.parametrize( - "operation,simplified_query", + "operation,simplified_query,snowpark_connect_simplified_query", [ # Flattened ( lambda df: df.filter(col("A") > 1).select(col("B") + 1), 'SELECT ("B" + 1{POSTFIX}) FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX})', + None, ), # Flattened, if there are duplicate column names across the parent/child, WHERE is evaluated on subquery first, so we could flatten in this case ( lambda df: df.filter(col("A") > 1).select((col("B") + 1).alias("A")), 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX})', + None, ), # Flattened ( @@ -1649,21 +1707,26 @@ def test_chained_sort(session): .select(col("A"), col("B"), lit(12).alias("TWELVE")) .filter(col("A") > 2), 'SELECT "A", "B", 12 :: INT AS "TWELVE" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE (("A" > 1{POSTFIX}) AND ("A" > 2{POSTFIX}))', + None, ), - # Flattened if the dropped columns are not used in filter + # Flattened if the dropped columns are not used in filter in snowpark connect compatible mode + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A").alias("C"), col("B").alias("D")) .filter(col("C") > 2) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE ("C" > 2{POSTFIX})', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: 
INT)) WHERE (("A" >= 1{POSTFIX}) AND ("C" > 2{POSTFIX}))', ), - # Flattened if the dropped columns are not in the filter clause's dependent columns + # Flattened if the dropped columns are not in the filter clause's dependent columns in snowpark connect compatible mode + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A").alias("C"), col("B").alias("D")) .filter((col("C") + 1) > 2) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE (("C" + 1{POSTFIX}) > 2{POSTFIX})', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("C" + 1{POSTFIX}) > 2{POSTFIX}))', ), # Not fully flattened, since col("A") > 1 and col("A") > 2 are referring to different columns @@ -1672,36 +1735,44 @@ def test_chained_sort(session): .select((col("B") + 1).alias("A")) .filter(col("A") > 2), 'SELECT * FROM ( SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX}) ) WHERE ("A" > 2{POSTFIX})', + None, ), # Not flattened, since A is updated in the select after filter. 
( lambda df: df.filter(col("A") > 1).select("A", seq1(0)), 'SELECT "A", seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX}) )', + None, ), # Not flattened, since we cannot detect dependent columns from sql_expr ( lambda df: df.filter(sql_expr("A > 1")).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE A > 1 )', + None, ), # Not flattened, since we cannot flatten when the subquery uses positional parameter ($1) ( lambda df: df.filter(col("$1") > 1).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) WHERE ("$1" > 1{POSTFIX}) )', + None, ), # Not flattened if a dropped column is used in the filter clause + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A"), col("B").alias("D")) .filter(col("D") > -3) .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE ("D" > -3{POSTFIX})', 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("D" > -3{POSTFIX})))', ), # Not flattened if a dropped column is used in the select clause's dependent columns + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query ( lambda df: df.filter(col("A") >= 1) .select(col("A"), col("B").alias("D")) .filter((col("D") - 1) > -4) 
.select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE (("D" - 1{POSTFIX}) > -4{POSTFIX})', 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("D" - 1{POSTFIX}) > -4{POSTFIX})))', ), # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns @@ -1710,10 +1781,25 @@ def test_chained_sort(session): .filter((col("B") - 1) > -4) .select((col("A") + 1).alias("E")), 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + None, ), ], ) -def test_select_after_filter(setup_reduce_cast, session, operation, simplified_query): +def test_select_after_filter( + setup_reduce_cast, + session, + operation, + simplified_query, + snowpark_connect_compatible_mode, + monkeypatch, + snowpark_connect_simplified_query, +): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + simplified_query = snowpark_connect_simplified_query or simplified_query + session.sql_simplifier_enabled = False df1 = session.create_dataframe([[1, -2], [3, -4]], schema=["a", "b"]) @@ -1733,43 +1819,50 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q ) == Utils.normalize_sql(simplified_query) +@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False]) @pytest.mark.parametrize( - "operation,simplified_query,execute_sql", + "operation,simplified_query,snowpark_connect_simplified_query,execute_sql", [ # Flattened ( lambda df: 
df.order_by(col("A")).select(col("B") + 1), 'SELECT ("B" + 1{POSTFIX}) FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST', + None, True, ), # Not flattened because SEQ1() is a data generator. ( lambda df: df.order_by(col("A")).select(seq1(0)), 'SELECT seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', + None, True, ), # Not flattened, unlike filter, current query takes precendence when there are duplicate column names from a ORDERBY clause ( lambda df: df.order_by(col("A")).select((col("B") + 1).alias("A")), 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', + None, True, ), # Not flattened, since we cannot detect dependent columns from sql_expr ( lambda df: df.order_by(sql_expr("A")).select(col("B")), 'SELECT "B" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY A ASC NULLS FIRST )', + None, True, ), # Not flattened, since we cannot flatten when the subquery uses positional parameter ($1) ( lambda df: df.order_by(col("$1")).select(col("B")), 'SELECT "B" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) ORDER BY "$1" ASC NULLS FIRST )', + None, True, ), # Not flattened, skip execution since this would result in SnowparkSQLException ( lambda df: df.order_by(col("C")).select((col("A") + col("B")).alias("C")), 'SELECT ("A" + "B") AS "C" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "C" ASC NULLS FIRST )', + None, False, ), # Flattened @@ -1779,13 +1872,15 @@ def test_select_after_filter(setup_reduce_cast, session, operation, 
simplified_q .order_by(col("B")) .select(col("A")), 'SELECT "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "B" ASC NULLS FIRST, "A" ASC NULLS FIRST', + None, True, ), - # Flattened if the dropped columns are not used in filter + # Flattened if the dropped columns are not used in filter in the snowpark connect compatible mode ( lambda df: df.select(col("A").alias("C"), col("B").alias("D")) .order_by(col("C")) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "C" ASC NULLS FIRST', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "C" ASC NULLS FIRST', True, ), @@ -1794,6 +1889,7 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q lambda df: df.select(col("A").alias("C"), col("B").alias("D")) .order_by(col("C") + 1) .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', True, ), @@ -1802,6 +1898,7 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q lambda df: df.select(col("A"), col("B").alias("D")) .order_by(col("D")) .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS FIRST', 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST)', True, ), @@ -1810,6 +1907,7 @@ def 
test_select_after_filter(setup_reduce_cast, session, operation, simplified_q lambda df: df.select(col("A"), col("B").alias("D")) .order_by(col("D") - 1) .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST', 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST)', True, ), @@ -1819,13 +1917,27 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q .filter((col("B") - 1) > -4) .select((col("A") + 1).alias("E")), 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + None, True, ), ], ) def test_select_after_orderby( - setup_reduce_cast, session, operation, simplified_query, execute_sql + setup_reduce_cast, + session, + operation, + simplified_query, + execute_sql, + snowpark_connect_compatible_mode, + monkeypatch, + snowpark_connect_simplified_query, ): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + simplified_query = snowpark_connect_simplified_query or simplified_query + session.sql_simplifier_enabled = False df1 = session.create_dataframe([[1, -2], [3, -4]], schema=["a", "b"]) @@ -2012,3 +2124,19 @@ def test_select_distinct( ) finally: session.conf.set("use_simplified_query_generation", original) + + +def test_retrieving_aggregation_funcs(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, 
"_aggregation_function_set", set()) + assert not context._aggregation_function_set + session._retrieve_aggregation_function_list() + assert context._aggregation_function_set + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", False) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + assert not context._aggregation_function_set + session._retrieve_aggregation_function_list() + assert not context._aggregation_function_set From 3f7a98e08e2985888fcdf60203d3f6e038e80013 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 31 Oct 2025 13:18:06 -0700 Subject: [PATCH 07/15] fix local testing --- src/snowflake/snowpark/mock/_select_statement.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/mock/_select_statement.py b/src/snowflake/snowpark/mock/_select_statement.py index d21a86aeda..2a149db59e 100644 --- a/src/snowflake/snowpark/mock/_select_statement.py +++ b/src/snowflake/snowpark/mock/_select_statement.py @@ -412,7 +412,7 @@ def filter(self, col: Expression) -> "MockSelectStatement": else: dependent_columns = derive_dependent_columns(col) can_be_flattened = can_clause_dependent_columns_flatten( - dependent_columns, self.column_states + dependent_columns, self.column_states, "filter" ) if can_be_flattened: new = copy(self) @@ -433,7 +433,7 @@ def sort(self, cols: List[Expression]) -> "MockSelectStatement": else: dependent_columns = derive_dependent_columns(*cols) can_be_flattened = can_clause_dependent_columns_flatten( - dependent_columns, self.column_states + dependent_columns, self.column_states, "sort" ) if can_be_flattened: new = copy(self) From 8035696a19c40011d34813dd41be9491b1851328 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 31 Oct 2025 16:31:22 -0700 Subject: [PATCH 08/15] update --- .../_internal/analyzer/select_statement.py | 99 +++++++++++++------ src/snowflake/snowpark/session.py | 4 + 2 files changed, 72 insertions(+), 31 deletions(-) diff --git 
a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index c0bf4ce207..ca0b677a78 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -2306,52 +2306,89 @@ def derive_column_states_from_subquery( return column_states -def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: +def _check_expressions_for_types( + expressions: Optional[List["Expression"]], + check_data_gen: bool = False, + check_window: bool = False, + check_aggregation: bool = False, +) -> bool: + """Efficiently check if expressions contain specific types in a single pass. + + Args: + expressions: List of expressions to check + check_data_gen: Check for data generator functions + check_window: Check for window functions + check_aggregation: Check for aggregation functions + + Returns: + True if any requested type is found + """ if expressions is None: return False + for exp in expressions: - if not context._is_snowpark_connect_compatible_mode and isinstance( - exp, WindowExpression - ): + if exp is None: + continue + + # Check window functions + if check_window and isinstance(exp, WindowExpression): return True - if isinstance(exp, FunctionExpression) and ( - exp.is_data_generator - or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION + + # Check data generators (including window in non-connect mode) + if check_data_gen: + # In non-connect mode, windows are treated as data generators + if not context._is_snowpark_connect_compatible_mode and isinstance( + exp, WindowExpression + ): + return True + # Check actual data generator functions + if isinstance(exp, FunctionExpression) and ( + exp.is_data_generator + or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION + ): + # https://docs.snowflake.com/en/sql-reference/functions-data-generation + return True + + # Check aggregation functions + if 
check_aggregation and isinstance(exp, FunctionExpression): + if exp.name.lower() in context._aggregation_function_set: + return True + + # Recursively check children + if _check_expressions_for_types( + exp.children, check_data_gen, check_window, check_aggregation ): - # https://docs.snowflake.com/en/sql-reference/functions-data-generation - return True - if exp is not None and has_data_generator_exp(exp.children): return True + return False -def has_window_function_exp(expressions: Optional[List["Expression"]]) -> bool: - if expressions is None: - return False - for exp in expressions: - if isinstance(exp, WindowExpression): - return True - if exp is not None and has_window_function_exp(exp.children): - return True - return False +def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: + """Check if expressions contain data generator functions. + + Note: + In non-connect mode, check_data_gen checks both data generator and window expressions for backward compatibility. + In connect mode, check_data_gen only checks data generator expressions. + """ + return _check_expressions_for_types(expressions, check_data_gen=True) def has_data_generator_or_window_function_exp( expressions: Optional[List["Expression"]], ) -> bool: + """Check if expressions contain data generators or window functions. + + Optimized to do a single pass checking both types simultaneously.
+ """ if not context._is_snowpark_connect_compatible_mode: - return has_data_generator_exp(expressions) - return has_data_generator_exp(expressions) or has_window_function_exp(expressions) + # In non-connect mode, windows are already treated as data generators + return _check_expressions_for_types(expressions, check_data_gen=True) + # In connect mode, check both in a single pass + return _check_expressions_for_types( + expressions, check_data_gen=True, check_window=True + ) def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool: - if expressions is None: - return False - for exp in expressions: - if isinstance(exp, FunctionExpression) and ( - exp.name.lower() in context._aggregation_function_set - ): - return True - if exp is not None and has_aggregation_function_exp(exp.children): - return True - return False + """Check if expressions contain aggregation functions.""" + return _check_expressions_for_types(expressions, check_aggregation=True) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 6fea3308ae..85915c082b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -4881,6 +4881,10 @@ def _retrieve_aggregation_function_list(self) -> None: "Unable to get aggregation functions from the database: %s", e, ) + # we raise error here as a pessimistic tactics + # the reason is that if we fail to retrieve the aggregation function list, we have empty set + # the simplifier will flatten the query which contains aggregation functions leading to incorrect results + raise with context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) From 660a5a3b730bfe4dddcf1a8ccf69dec043380dbb Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Tue, 7 Apr 2026 09:03:19 -0700 Subject: [PATCH 09/15] Re-sort for df.sort().select() when possible for Snowpark Connect --- .../_internal/analyzer/select_statement.py | 20 ++- tests/integ/test_simplifier_suite.py | 
123 +++++++++++++++++- 2 files changed, 138 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index cd661d6b66..3de2e23acc 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1486,8 +1486,26 @@ def select(self, cols: List[Expression]) -> "SelectStatement": self.df_ast_ids.copy() if self.df_ast_ids is not None else None ) else: + new_order_by = None + if context._is_snowpark_connect_compatible_mode and self.order_by: + order_by_dependent_columns = derive_dependent_columns(*self.order_by) + if order_by_dependent_columns in ( + COLUMN_DEPENDENCY_DOLLAR, + COLUMN_DEPENDENCY_ALL, + ) or any( + _col not in self.column_states + or self.column_states[_col].change_state + in (ColumnChangeState.CHANGED_EXP, ColumnChangeState.DROPPED) + for _col in order_by_dependent_columns + ): + new_order_by = None + else: + new_order_by = self.order_by new = SelectStatement( - projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer + projection=cols, + from_=self.to_subqueryable(), + order_by=new_order_by, + analyzer=self.analyzer, ) new._merge_projection_complexity_with_subquery = ( can_select_projection_complexity_be_merged( diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 9c6ddbe296..327408292a 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -1863,14 +1863,14 @@ def test_select_after_filter( ( lambda df: df.order_by(col("A")).select(seq1(0)), 'SELECT seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', - None, + 'SELECT seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS 
FIRST ) ORDER BY "A" ASC NULLS FIRST', True, ), # Not flattened, unlike filter, current query takes precendence when there are duplicate column names from a ORDERBY clause ( lambda df: df.order_by(col("A")).select((col("B") + 1).alias("A")), 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', - None, + 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST ) ORDER BY "A" ASC NULLS FIRST', True, ), # Not flattened, since we cannot detect dependent columns from sql_expr @@ -1928,7 +1928,7 @@ def test_select_after_filter( .order_by(col("D")) .select(col("A").alias("E")), 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS FIRST', - 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST)', + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST) ORDER BY "D" ASC NULLS FIRST', True, ), # Not flattened if a dropped new column is used in the order by clause's dependent columns @@ -1937,7 +1937,27 @@ def test_select_after_filter( .order_by(col("D") - 1) .select((col("A") + 1).alias("E")), 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST', - 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST)', 
+ 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST', + True, + ), + # Not flattened if a dropped new column is used in the order by clause, select without alias + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D")) + .select(col("A")), + 'SELECT "A" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS FIRST', + 'SELECT "A" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST) ORDER BY "D" ASC NULLS FIRST', + True, + ), + # Not flattened with multiple order by columns and partial drop + ( + lambda df: df.select( + col("A"), col("B").alias("D"), (col("A") + col("B")).alias("E") + ) + .order_by(col("D"), col("E")) + .select(col("A")), + 'SELECT "A" FROM (SELECT "A", "B" AS "D", ("A" + "B") AS "E" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS FIRST, "E" ASC NULLS FIRST', + 'SELECT "A" FROM (SELECT "A", "B" AS "D", ("A" + "B") AS "E" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST, "E" ASC NULLS FIRST) ORDER BY "D" ASC NULLS FIRST, "E" ASC NULLS FIRST', True, ), # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns @@ -1986,6 +2006,101 @@ def test_select_after_orderby( Utils.check_answer(operation(df1), operation(df2)) +@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False]) +def test_order_by_preserved_after_non_flattened_select( + session, monkeypatch, is_snowpark_connect_compatible_mode +): + """When select() can't be flattened and drops 
columns used in order_by, + compat mode should preserve the order_by on the outer SelectStatement + so that subsequent operations like limit() maintain correct ordering. + In both modes, data ordering should be correct.""" + if is_snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + + df = session.create_dataframe( + [[1, 3], [2, 1], [3, 2]], + schema=["a", "b"], + ) + + # order_by uses a column (D) that gets dropped by the subsequent select + df1 = df.select(col("A"), col("B").alias("D")).order_by(col("D")).select(col("A")) + + assert "ORDER BY" in df1.queries["queries"][-1] + # B values [3,1,2] → sorted D [1,2,3] → A order is [2,3,1] + Utils.check_answer(df1, [Row(2), Row(3), Row(1)], sort=False) + + if is_snowpark_connect_compatible_mode: + # In compat mode, the order_by should be preserved on the outer statement + assert df1._select_statement.order_by is not None + + # Verify ordering survives through limit() + limited = df1.limit(2) + Utils.check_answer(limited, [Row(2), Row(3)], sort=False) + + # Verify with multiple order_by columns where some are dropped + df2 = ( + df.select(col("A"), col("B").alias("D"), (col("A") + col("B")).alias("E")) + .order_by(col("D"), col("E")) + .select(col("A")) + ) + Utils.check_answer(df2, [Row(2), Row(3), Row(1)], sort=False) + + # Verify the chained select().orderBy().select() pattern from the bug report: + # columns used in orderBy are dropped by the outer select + df3 = ( + df.select(col("A"), col("A").alias("S"), col("B").alias("SK")) + .order_by(col("A"), col("SK")) + .select("A", "S") + ) + assert "ORDER BY" in df3.queries["queries"][-1] + Utils.check_answer(df3, [Row(1, 1), Row(2, 2), Row(3, 3)], sort=False) + + # Verify order survives through limit on the chained pattern + limited3 = df3.limit(2) + Utils.check_answer(limited3, [Row(1, 1), Row(2, 2)], sort=False) + + +def 
test_order_by_outer_hoist_skipped_when_unsafe_compat_mode(session, monkeypatch): + """When select() is not flattened, compat mode only hoists order_by to the outer + SelectStatement when dependency is known and no ORDER BY column is CHANGED_EXP/DROPPED + (see select_statement non-flatten branch).""" + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + + df = session.create_dataframe( + [[1, 3], [2, 1], [3, 2]], + schema=["A", "B"], + ) + + # sql_expr ORDER BY: dependent columns are ALL/unknown -> do not hoist + df_sql = df.order_by(sql_expr("A")).select(col("B")) + assert df_sql._select_statement.order_by is None + assert "ORDER BY" in df_sql.queries["queries"][-1] + + # Positional $n ORDER BY: DOLLAR dependency -> do not hoist + df_dollar = df.order_by(col("$1")).select(col("B")) + assert df_dollar._select_statement.order_by is None + + # ORDER BY column is CHANGED_EXP in the inner projection -> do not hoist + df_changed = ( + df.select((col("A") * 2).alias("A"), col("B")) + .order_by(col("A")) + .select(col("B")) + ) + assert df_changed._select_statement.order_by is None + # Inner sort by computed A (2, 4, 6) -> B order 3, 1, 2 + Utils.check_answer(df_changed, [Row(3), Row(1), Row(2)], sort=False) + + # Contrast: ordering by a plain renamed column still hoists (safe path) + df_hoist = ( + df.select(col("A"), col("B").alias("D")).order_by(col("D")).select(col("A")) + ) + assert df_hoist._select_statement.order_by is not None + + def test_window_with_filter(session): df = session.create_dataframe([[0], [1]], schema=["A"]) df = ( From 70cabab74b79768cfc67cdbb643a782ade2e4e53 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Tue, 7 Apr 2026 11:01:19 -0700 Subject: [PATCH 10/15] simplify code --- .../_internal/analyzer/select_statement.py | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py 
b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 3de2e23acc..1ee700ce9f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -2387,14 +2387,7 @@ def _check_expressions_for_types( if check_window and isinstance(exp, WindowExpression): return True - # Check data generators (including window in non-connect mode) if check_data_gen: - # In non-connect mode, windows are treated as data generators - if not context._is_snowpark_connect_compatible_mode and isinstance( - exp, WindowExpression - ): - return True - # Check actual data generator functions if isinstance(exp, FunctionExpression) and ( exp.is_data_generator or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION @@ -2420,23 +2413,20 @@ def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: """Check if expressions contain data generator functions. Note: - In non-connect mode, check_data_gen check both data generator and window expressions for backward compatibility. - In connect mode, check_data_gen only checks data generator expressions. + In non-connect mode, window expressions are also treated as data generators + for backward compatibility. """ + if not context._is_snowpark_connect_compatible_mode: + return _check_expressions_for_types( + expressions, check_data_gen=True, check_window=True + ) return _check_expressions_for_types(expressions, check_data_gen=True) def has_data_generator_or_window_function_exp( expressions: Optional[List["Expression"]], ) -> bool: - """Check if expressions contain data generators or window functions. - - Optimized to do a single pass checking both types simultaneously. 
- """ - if not context._is_snowpark_connect_compatible_mode: - # In non-connect mode, windows are already treated as data generators - return _check_expressions_for_types(expressions, check_data_gen=True) - # In connect mode, check both in a single pass + """Check if expressions contain data generators or window functions.""" return _check_expressions_for_types( expressions, check_data_gen=True, check_window=True ) From 7df91d6bb5996aba610f58a7ed8d9e8973ecaac5 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 13 Apr 2026 10:17:24 -0700 Subject: [PATCH 11/15] Fix select->sort->filter where sort has more columns than select --- .../_internal/analyzer/select_statement.py | 28 ++++++++++-- tests/integ/test_simplifier_suite.py | 45 +++++++++++++++---- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 3de2e23acc..7a84ecdfbb 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1487,23 +1487,45 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ) else: new_order_by = None + new_from = self if context._is_snowpark_connect_compatible_mode and self.order_by: order_by_dependent_columns = derive_dependent_columns(*self.order_by) if order_by_dependent_columns in ( COLUMN_DEPENDENCY_DOLLAR, COLUMN_DEPENDENCY_ALL, - ) or any( + ): + new_order_by = None + elif any( + col not in self.from_.column_states + and col not in self.column_states + for col in order_by_dependent_columns + ): + new_order_by = None + elif any( _col not in self.column_states or self.column_states[_col].change_state in (ColumnChangeState.CHANGED_EXP, ColumnChangeState.DROPPED) for _col in order_by_dependent_columns ): - new_order_by = None + new_from = copy(self) + missing_columns = ( + order_by_dependent_columns + - new_from.column_states.active_columns + ) + 
new_from.projection = new_from.projection + [ + Attribute(col, DataType()) for col in missing_columns + ] + new_from.column_states = derive_column_states_from_subquery( + new_from.projection, new_from.from_ + ) + new_from._commented_sql = None + new_from._sql_query = None + new_order_by = self.order_by else: new_order_by = self.order_by new = SelectStatement( projection=cols, - from_=self.to_subqueryable(), + from_=new_from.to_subqueryable(), order_by=new_order_by, analyzer=self.analyzer, ) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 327408292a..e1e4dfe2c2 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -2063,9 +2063,23 @@ def test_order_by_preserved_after_non_flattened_select( def test_order_by_outer_hoist_skipped_when_unsafe_compat_mode(session, monkeypatch): - """When select() is not flattened, compat mode only hoists order_by to the outer - SelectStatement when dependency is known and no ORDER BY column is CHANGED_EXP/DROPPED - (see select_statement non-flatten branch).""" + """Snowpark Connect compat mode: non-flattened ``select()`` may clear outer ``order_by`` + when ORDER BY deps are ALL/DOLLAR, or missing from both ``from_.column_states`` and + ``self.column_states`` (non-flatten path only). + + When the plan **flattens** (e.g. ``order_by(col("Z")).select(col("A"))`` with no ``Z`` + in the schema), ORDER BY is merged into a single ``SELECT ... ORDER BY "Z"`` and + ``_select_statement.order_by`` stays populated. + + To hit the non-flattened path where ``order_by`` deps are missing from both + ``from_.column_states`` and ``self.column_states`` (lines 1498-1499), use + ``order_by(col("Z")).select(seq1(0))``: ``seq1`` forces non-flatten, so outer + ``order_by`` is cleared in compat mode. 
+ + If ORDER BY references CHANGED_EXP/DROPPED keys present in ``self.column_states``, + the inner projection may be augmented before wrapping; outer ``order_by`` is still + set so limit() and similar keep a deterministic sort (see select_statement.select). + """ import snowflake.snowpark.context as context monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) @@ -2075,26 +2089,41 @@ def test_order_by_outer_hoist_skipped_when_unsafe_compat_mode(session, monkeypat schema=["A", "B"], ) - # sql_expr ORDER BY: dependent columns are ALL/unknown -> do not hoist + # sql_expr ORDER BY: ALL/unknown deps -> do not set outer order_by df_sql = df.order_by(sql_expr("A")).select(col("B")) assert df_sql._select_statement.order_by is None assert "ORDER BY" in df_sql.queries["queries"][-1] - # Positional $n ORDER BY: DOLLAR dependency -> do not hoist + # Positional $n ORDER BY: DOLLAR -> do not set outer order_by df_dollar = df.order_by(col("$1")).select(col("B")) assert df_dollar._select_statement.order_by is None - # ORDER BY column is CHANGED_EXP in the inner projection -> do not hoist + # ORDER BY unknown column "Z" then select: typically flattens to one SELECT with + # ORDER BY "Z" (does not use the non-flatten compat outer order_by=None branch). + df_order_by_z = df.order_by(col("Z")).select(col("A")) + assert df_order_by_z._select_statement.order_by is not None + normalized_z = Utils.normalize_sql(df_order_by_z.queries["queries"][-1]) + assert "ORDER BY" in normalized_z + assert '"Z"' in normalized_z + + # Non-flatten: seq1() in select blocks flattening; ORDER BY "Z" with Z not in either + # column_states -> compat mode clears outer order_by (select_statement lines 1498-1499). 
+ df_order_by_z_seq = df.order_by(col("Z")).select(seq1(0)) + assert df_order_by_z_seq._select_statement.order_by is None + assert "ORDER BY" in Utils.normalize_sql(df_order_by_z_seq.queries["queries"][-1]) + + # ORDER BY on CHANGED_EXP column: augment inner projection; outer still has order_by df_changed = ( df.select((col("A") * 2).alias("A"), col("B")) .order_by(col("A")) .select(col("B")) ) - assert df_changed._select_statement.order_by is None + assert df_changed._select_statement.order_by is not None # Inner sort by computed A (2, 4, 6) -> B order 3, 1, 2 Utils.check_answer(df_changed, [Row(3), Row(1), Row(2)], sort=False) - # Contrast: ordering by a plain renamed column still hoists (safe path) + # ORDER BY on alias only in this select (not in from_.column_states): still in + # self.column_states, so outer order_by is set (not the "missing from both" case) df_hoist = ( df.select(col("A"), col("B").alias("D")).order_by(col("D")).select(col("A")) ) From 7df994e1a244c8b493242eeb3d4f37b419029337 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Mon, 13 Apr 2026 13:28:17 -0700 Subject: [PATCH 12/15] improve agg function retrieval --- src/snowflake/snowpark/context.py | 108 ++++++++++++++++++++++++++++++ src/snowflake/snowpark/session.py | 48 ++++++++----- 2 files changed, 141 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index d4ae6cd0be..8329741835 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -36,6 +36,114 @@ ) # lower cased names of aggregation functions, used in sql simplification _aggregation_function_set_lock = threading.RLock() +# Hardcoded fallback for system built-in aggregation functions. +# Used when the dynamic query fails to retrieve the list from the database. 
+# +# Generated via: +# show functions ->> select "name" from $1 where "is_aggregate" = 'Y' +# +# Entries with parentheses in the name (COUNT(*), COUNT_INTERNAL(*)) are excluded +# because FunctionExpression.name stores only the function name without parens, +# so they can never match at the lookup site. +_KNOWN_AGGREGATION_FUNCTIONS = frozenset( + [ + "accumulate", + "agg", + "ai_agg", + "ai_summarize_agg", + "any_value", + "approximate_count_distinct", + "approximate_jaccard_index", + "approximate_similarity", + "approx_count_distinct", + "approx_percentile", + "approx_percentile_accumulate", + "approx_percentile_combine", + "approx_top_k", + "approx_top_k_accumulate", + "approx_top_k_combine", + "arrayagg", + "array_agg", + "array_union_agg", + "array_unique_agg", + "avg", + "bitandagg", + "bitand_agg", + "bitmap_construct_agg", + "bitmap_or_agg", + "bitoragg", + "bitor_agg", + "bitxoragg", + "bitxor_agg", + "bit_andagg", + "bit_and_agg", + "bit_oragg", + "bit_or_agg", + "bit_xoragg", + "bit_xor_agg", + "booland_agg", + "boolor_agg", + "boolxor_agg", + "corr", + "count", + "count_if", + "count_internal", + "covar_pop", + "covar_samp", + "datasketches_hll", + "datasketches_hll_accumulate", + "datasketches_hll_combine", + "first_value", + "hash_agg", + "hll", + "hll_accumulate", + "hll_combine", + "kurtosis", + "last_value", + "listagg", + "max", + "max_by", + "median", + "min", + "minhash", + "minhash_combine", + "min_by", + "mode", + "objectagg", + "object_agg", + "percentile_cont", + "percentile_disc", + "regr_avgx", + "regr_avgy", + "regr_count", + "regr_intercept", + "regr_r2", + "regr_slope", + "regr_sxx", + "regr_sxy", + "regr_syy", + "skew", + "stddev", + "stddev_pop", + "stddev_samp", + "st_intersection_agg_geography_internal", + "st_union_agg_geography_internal", + "sum", + "sum_internal", + "sum_internal_real", + "sum_real", + "variance", + "variance_pop", + "variance_samp", + "var_pop", + "var_samp", + "vector_avg", + "vector_max", + "vector_min", + 
"vector_sum", + ] +) + _cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session. # Following are internal-only global flags, used to enable development features. diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index d98f0c1401..18d8268ac4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5044,21 +5044,39 @@ def _retrieve_aggregation_function_list(self) -> None: retrieved_set = set() - for sql in [ - """select function_name from information_schema.functions where is_aggregate = 'YES'""", - """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""", - ]: - try: - retrieved_set.update({r[0].lower() for r in self.sql(sql).collect()}) - except BaseException as e: - _logger.debug( - "Unable to get aggregation functions from the database: %s", - e, - ) - # we raise error here as a pessimistic tactics - # the reason is that if we fail to retrieve the aggregation function list, we have empty set - # the simplifier will flatten the query which contains aggregation functions leading to incorrect results - raise + # User-defined aggregation functions + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect() + } + ) + except BaseException as e: + _logger.debug( + "Unable to get user-defined aggregation functions: %s", + e, + ) + + # System built-in aggregation functions + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect() + } + ) + except BaseException as e: + _logger.debug( + "Unable to get system aggregation functions, " + "falling back to hardcoded list: %s", + e, + ) + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) with 
context._aggregation_function_set_lock: context._aggregation_function_set.update(retrieved_set) From 7ba0dc0bc9c485d54d0d85986c212fce447e4072 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Mon, 13 Apr 2026 17:24:07 -0700 Subject: [PATCH 13/15] Revert "SNOW-3266242: Support TRY_CAST with user-provided schema in DataFrameReader (#4138)" This reverts commit 883803a7da25ce471dfbae8caad9b5a50b5474f6. --- CHANGELOG.md | 4 -- src/snowflake/snowpark/dataframe_reader.py | 45 +------------------ .../scala/test_dataframe_reader_suite.py | 40 ----------------- 3 files changed, 1 insertion(+), 88 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cddf0a0796..a9c5b71185 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,10 +16,6 @@ - Allow user input schema when reading Parquet file on stage. -#### Bug Fixes - -- Fixed a bug that `TRY_CAST` reader option is ignored when calling `DataFrameReader.schema().csv()`. - #### Improvements - Restored the following query improvements that were reverted in 1.47.0 due to bugs: diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 22bfbf0f15..aab601a1f9 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -798,9 +798,6 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: self._file_type = "CSV" schema_to_cast, transformations = None, None - # this parameter determine whether generate schema_to_cast and transformations when user schema exist - # schema_to_cast and transformations is needed to apply try_cast to data - user_schema_with_try_cast = False if not self._user_schema: if not self._infer_schema: @@ -837,13 +834,7 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations = [] else: self._cur_options["INFER_SCHEMA"] = False - try_cast = self._cur_options.get("TRY_CAST", False) - ( - schema, - schema_to_cast, - transformations, - ) = 
self._get_schema_from_csv_user_input(self._user_schema, try_cast) - user_schema_with_try_cast = try_cast + schema = self._user_schema._to_attributes() metadata_project, metadata_schema = self._get_metadata_project_and_schema() @@ -869,7 +860,6 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations=transformations, metadata_project=metadata_project, metadata_schema=metadata_schema, - use_user_schema=user_schema_with_try_cast, ), analyzer=self._session._analyzer, ), @@ -890,7 +880,6 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations=transformations, metadata_project=metadata_project, metadata_schema=metadata_schema, - use_user_schema=user_schema_with_try_cast, ), _ast_stmt=stmt, _emit_ast=_emit_ast, @@ -1399,38 +1388,6 @@ def _infer_schema_for_file_format( return new_schema, schema_to_cast, read_file_transformations, None - def _get_schema_from_csv_user_input( - self, user_schema: StructType, try_cast: bool - ) -> Tuple[List, Optional[List], Optional[List]]: - """ - This function accept a user input structtype and return schemas needed for reading CSV file. - CSV files are processed differently than semi-structured file so need a different helper function. 
- """ - schema_to_cast = [] - transformations = [] - new_schema = [] - for index, field in enumerate(user_schema.fields, start=1): - new_schema.append( - Attribute( - field.column_identifier.quoted_name, - field.datatype, - field.nullable, - ) - ) - sf_type = convert_sp_to_sf_type(field.datatype) - # TODO: SNOW-3324409 Support relaxed schema when read csv in copy mode - if try_cast: - identifier = f"TRY_CAST(${index} AS {sf_type})" - schema_to_cast.append((identifier, field.name)) - transformations.append(sql_expr(identifier)) - - read_file_transformations = [t._expression.sql for t in transformations] - # schema_to_cast and read_file_transformations should only exist when try_cast is True - # this is meant to not break the current behavior - if not try_cast: - return new_schema, None, None - return new_schema, schema_to_cast, read_file_transformations - def _get_schema_from_user_input( self, user_schema: StructType ) -> Tuple[List, List, List]: diff --git a/tests/integ/scala/test_dataframe_reader_suite.py b/tests/integ/scala/test_dataframe_reader_suite.py index cc1dfd1cc9..be91633c10 100644 --- a/tests/integ/scala/test_dataframe_reader_suite.py +++ b/tests/integ/scala/test_dataframe_reader_suite.py @@ -375,46 +375,6 @@ def test_read_csv(session, mode): assert "is out of range" in str(ex_info.value) -@pytest.mark.xfail( - "config.getoption('local_testing_mode', default=False)", - reason="SNOW-1435112: csv infer schema option is not supported", - run=False, -) -@pytest.mark.parametrize("mode", ["select", "copy"]) -def test_read_csv_with_user_schema_try_cast(session, mode): - reader = get_reader(session, mode) - test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" - try_cast_schema = StructType( - [ - StructField("A", LongType()), - StructField("B", StringType()), - StructField("C", DoubleType()), - ] - ) - df_try_cast = ( - reader.schema(try_cast_schema).option("TRY_CAST", True).csv(test_file_on_stage) - ) - try_cast_res = df_try_cast.collect() - 
try_cast_res.sort(key=lambda x: x[0]) - assert try_cast_res == [Row(A=1, B="one", C=1.2), Row(A=2, B="two", C=2.2)] - assert df_try_cast.schema == try_cast_schema - - try_cast_schema = StructType( - [ - StructField("A", LongType()), - StructField("B", LongType()), - StructField("C", DoubleType()), - ] - ) - df_try_cast = ( - reader.schema(try_cast_schema).option("TRY_CAST", True).csv(test_file_on_stage) - ) - try_cast_res = df_try_cast.collect() - try_cast_res.sort(key=lambda x: x[0]) - assert try_cast_res == [Row(A=1, B=None, C=1.2), Row(A=2, B=None, C=2.2)] - assert df_try_cast.schema == try_cast_schema - - @pytest.mark.xfail( "config.getoption('local_testing_mode', default=False)", reason="SNOW-1435112: csv infer schema option is not supported", From 82a556cf4771b50e86a642e52d01e60ba6146c77 Mon Sep 17 00:00:00 2001 From: Yijun Xie Date: Mon, 13 Apr 2026 23:29:58 -0700 Subject: [PATCH 14/15] Avoid error when select from --- .../_internal/analyzer/select_statement.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 7ef39b59a8..859df60443 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1541,12 +1541,18 @@ def select(self, cols: List[Expression]) -> "SelectStatement": new_from.projection = new_from.projection + [ Attribute(col, DataType()) for col in missing_columns ] - new_from.column_states = derive_column_states_from_subquery( + new_col_states = derive_column_states_from_subquery( new_from.projection, new_from.from_ ) - new_from._commented_sql = None - new_from._sql_query = None - new_order_by = self.order_by + if new_col_states is not None: + new_from.column_states = new_col_states + new_from._projection_in_str = None + new_from._commented_sql = None + new_from._sql_query = None + new_order_by = 
self.order_by + else: + new_from = self + new_order_by = None else: new_order_by = self.order_by new = SelectStatement( From b0c702b068fdb91a7a8543a4475d2f01544b519a Mon Sep 17 00:00:00 2001 From: May Liu Date: Tue, 14 Apr 2026 16:52:23 -0700 Subject: [PATCH 15/15] fix workload --- .../_internal/analyzer/select_statement.py | 47 ++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 859df60443..c67867e6ab 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1616,21 +1616,14 @@ def filter(self, col: Expression) -> "SelectStatement": def sort(self, cols: List[Expression]) -> "SelectStatement": can_be_flattened = ( (not self.flatten_disabled) - # limit order by and order by limit can cause big performance - # difference, because limit can stop table scanning whenever the - # number of record is satisfied. - # Therefore, disallow sql simplification when the - # current SelectStatement has a limit clause to avoid moving - # order by in front of limit. + # Disallow flattening when the current SelectStatement has a + # limit clause to avoid moving order by in front of limit. 
and (not self.limit_) and (not self.offset) and can_clause_dependent_columns_flatten( derive_dependent_columns(*cols), self.column_states, "sort" ) - and not has_data_generator_exp(self.projection) - # we do not check aggregation function here like filter - # in the case when aggregation function is in the projection - # order by is evaluated after aggregation, row info are not taken in the calculation + and not has_data_generator_or_window_function_exp(self.projection) ) if can_be_flattened: new = copy(self) @@ -2178,17 +2171,24 @@ def can_clause_dependent_columns_flatten( dc_state = subquery_column_states.get(dc) if dc_state: if dc_state.change_state == ColumnChangeState.CHANGED_EXP: - if ( - clause == "filter" - ): # where can not be flattened because 'where' is evaluated before projection, flattening leads to wrong result - # df.select((col('a') + 1).alias('a')).filter(col('a') > 5) -- this should be applied to the new 'a', flattening will use the old 'a' to evaluated + if clause == "filter": + return False + # sort + CHANGED_EXP: safe in SCOS mode since ORDER BY + # is evaluated after projection. Keep checking remaining + # columns though — another column may be unsafe. + elif not context._is_snowpark_connect_compatible_mode: return False - else: # clause == 'sort' - # df.select((col('a') + 1).alias('a')).sort(col('a')) -- this is valid to flatten because 'order by' is evaluated after projection - # however, if the order by is a data generator, it should not be flattened because generator is evaluated dynamically according to the order. - return context._is_snowpark_connect_compatible_mode elif dc_state.change_state == ColumnChangeState.NEW: - return context._is_snowpark_connect_compatible_mode + if clause == "sort" and dc_state.dependent_columns in ( + COLUMN_DEPENDENCY_DOLLAR, + COLUMN_DEPENDENCY_ALL, + ): + # Scalar subqueries in sort can trigger Snowflake + # internal errors when ORDER BY references them + # at the same SELECT level. 
+ return False + if not context._is_snowpark_connect_compatible_mode: + return False return True @@ -2454,9 +2454,14 @@ def _check_expressions_for_types( if exp.name.lower() in context._aggregation_function_set: return True - # Recursively check children + # Recursively check children. + # Some expression types (e.g. CaseWhen) store sub-expressions in + # _child_expressions rather than children; fall back to that. + sub_exps = exp.children + if not sub_exps and hasattr(exp, "_child_expressions"): + sub_exps = exp._child_expressions if _check_expressions_for_types( - exp.children, check_data_gen, check_window, check_aggregation + sub_exps, check_data_gen, check_window, check_aggregation ): return True