diff --git a/CHANGELOG.md b/CHANGELOG.md
index cddf0a0796..a9c5b71185 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,10 +16,6 @@
 - Allow user input schema when reading Parquet file on stage.
 
-#### Bug Fixes
-
-- Fixed a bug that `TRY_CAST` reader option is ignored when calling `DataFrameReader.schema().csv()`.
-
 #### Improvements
 
 - Restored the following query improvements that were reverted in 1.47.0 due to bugs:
diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
index d45c6bf944..c67867e6ab 100644
--- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py
+++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py
@@ -20,6 +20,7 @@
     Sequence,
     Set,
     Union,
+    Literal,
 )
 
 import snowflake.snowpark._internal.utils
@@ -86,6 +87,7 @@
     is_sql_select_statement,
     ExprAliasUpdateDict,
 )
+import snowflake.snowpark.context as context
 
 # Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
 # Python 3.9 can use both
@@ -1412,9 +1414,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
         ):
             # TODO: Clean up, this entire if case is parameter protection
             can_be_flattened = False
-        elif (self.where or self.order_by or self.limit_) and has_data_generator_exp(
-            cols
-        ):
+        elif (
+            self.where or self.order_by or self.limit_
+        ) and has_data_generator_or_window_function_exp(cols):
             can_be_flattened = False
         elif self.where and (
             (subquery_dependent_columns := derive_dependent_columns(self.where))
@@ -1425,6 +1427,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
                     subquery_dependent_columns & new_column_states.active_columns
                 )
             )
+            or (
+                # unflattenable condition: a dropped column is used in the subquery WHERE clause and its status is NEW or CHANGED in the subquery
+                # reason: we should not flatten because the dropped column is not available in the new query, leading to a WHERE clause error
+                # sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' cannot be flattened to 'select "b" from table where "c" > 1'
+                context._is_snowpark_connect_compatible_mode
+                and new_column_states.dropped_columns
+                and any(
+                    self.column_states[_col].change_state
+                    in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
+                    for _col in (
+                        subquery_dependent_columns & new_column_states.dropped_columns
+                    )
+                )
+            )
         ):
             can_be_flattened = False
         elif self.order_by and (
@@ -1437,6 +1453,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
                     subquery_dependent_columns & new_column_states.active_columns
                 )
             )
+            or (
+                # unflattenable condition: a dropped column is used in the subquery ORDER BY clause and its status is NEW or CHANGED in the subquery
+                # reason: we should not flatten because the dropped column is not available in the new query, leading to an ORDER BY clause error
+                # sample query: 'select "b" from (select "a" as "c", "b" from table order by "c")' cannot be flattened to 'select "b" from table order by "c"'
+                context._is_snowpark_connect_compatible_mode
+                and new_column_states.dropped_columns
+                and any(
+                    self.column_states[_col].change_state
+                    in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
+                    for _col in (
+                        subquery_dependent_columns & new_column_states.dropped_columns
+                    )
+                )
+            )
         ):
             can_be_flattened = False
         elif self.distinct_:
@@ -1482,8 +1512,54 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
                 self.df_ast_ids.copy() if self.df_ast_ids is not None else None
             )
         else:
+            new_order_by = None
+            
new_from = self + if context._is_snowpark_connect_compatible_mode and self.order_by: + order_by_dependent_columns = derive_dependent_columns(*self.order_by) + if order_by_dependent_columns in ( + COLUMN_DEPENDENCY_DOLLAR, + COLUMN_DEPENDENCY_ALL, + ): + new_order_by = None + elif any( + col not in self.from_.column_states + and col not in self.column_states + for col in order_by_dependent_columns + ): + new_order_by = None + elif any( + _col not in self.column_states + or self.column_states[_col].change_state + in (ColumnChangeState.CHANGED_EXP, ColumnChangeState.DROPPED) + for _col in order_by_dependent_columns + ): + new_from = copy(self) + missing_columns = ( + order_by_dependent_columns + - new_from.column_states.active_columns + ) + new_from.projection = new_from.projection + [ + Attribute(col, DataType()) for col in missing_columns + ] + new_col_states = derive_column_states_from_subquery( + new_from.projection, new_from.from_ + ) + if new_col_states is not None: + new_from.column_states = new_col_states + new_from._projection_in_str = None + new_from._commented_sql = None + new_from._sql_query = None + new_order_by = self.order_by + else: + new_from = self + new_order_by = None + else: + new_order_by = self.order_by new = SelectStatement( - projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer + projection=cols, + from_=new_from.to_subqueryable(), + order_by=new_order_by, + analyzer=self.analyzer, ) new._merge_projection_complexity_with_subquery = ( can_select_projection_complexity_be_merged( @@ -1504,12 +1580,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement": return new def filter(self, col: Expression) -> "SelectStatement": + self._session._retrieve_aggregation_function_list() can_be_flattened = ( (not self.flatten_disabled) and can_clause_dependent_columns_flatten( - derive_dependent_columns(col), self.column_states + derive_dependent_columns(col), self.column_states, "filter" ) - and not has_data_generator_exp(self.projection) + and not has_data_generator_or_window_function_exp(self.projection) + and not ( + context._is_snowpark_connect_compatible_mode + and has_aggregation_function_exp(self.projection) + ) # sum(col) as new_col, new_col can not be flattened in where clause and not (self.order_by and self.limit_ is not None) ) if can_be_flattened: @@ -1535,18 +1616,14 @@ def filter(self, col: Expression) -> "SelectStatement": def sort(self, cols: List[Expression]) -> "SelectStatement": can_be_flattened = ( (not self.flatten_disabled) - # limit order by and order by limit can cause big performance - # difference, because limit can stop table scanning whenever the - # number of record is satisfied. - # Therefore, disallow sql simplification when the - # current SelectStatement has a limit clause to avoid moving - # order by in front of limit. + # Disallow flattening when the current SelectStatement has a + # limit clause to avoid moving order by in front of limit. 
and (not self.limit_) and (not self.offset) and can_clause_dependent_columns_flatten( - derive_dependent_columns(*cols), self.column_states + derive_dependent_columns(*cols), self.column_states, "sort" ) - and not has_data_generator_exp(self.projection) + and not has_data_generator_or_window_function_exp(self.projection) ) if can_be_flattened: new = copy(self) @@ -1583,7 +1660,7 @@ def distinct(self) -> "SelectStatement": # .order_by(col1).select(col2).distinct() cannot be flattened because # SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL and (not (self.order_by and self.has_projection)) - and not has_data_generator_exp(self.projection) + and not has_data_generator_or_window_function_exp(self.projection) ) if can_be_flattened: new = copy(self) @@ -2074,7 +2151,12 @@ def can_projection_dependent_columns_be_flattened( def can_clause_dependent_columns_flatten( dependent_columns: Optional[AbstractSet[str]], subquery_column_states: ColumnStateDict, + clause: Literal["filter", "sort"], ) -> bool: + assert clause in ( + "filter", + "sort", + ), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}" if dependent_columns == COLUMN_DEPENDENCY_DOLLAR: return False elif ( @@ -2089,15 +2171,25 @@ def can_clause_dependent_columns_flatten( dc_state = subquery_column_states.get(dc) if dc_state: if dc_state.change_state == ColumnChangeState.CHANGED_EXP: - return False + if clause == "filter": + return False + # sort + CHANGED_EXP: safe in SCOS mode since ORDER BY + # is evaluated after projection. Keep checking remaining + # columns though — another column may be unsafe. + elif not context._is_snowpark_connect_compatible_mode: + return False elif dc_state.change_state == ColumnChangeState.NEW: - # Most of the time this can be flattened. But if a new column uses window function and this column - # is used in a clause, the sql doesn't work in Snowflake. - # For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work. - # But `select a, b as d from test_table where d = 1` works - # We can inspect whether the referenced new column uses window function. Here we are being - # conservative for now to not flatten the SQL. - return False + if clause == "sort" and dc_state.dependent_columns in ( + COLUMN_DEPENDENCY_DOLLAR, + COLUMN_DEPENDENCY_ALL, + ): + # Scalar subqueries in sort can trigger Snowflake + # internal errors when ORDER BY references them + # at the same SELECT level. + return False + if not context._is_snowpark_connect_compatible_mode: + return False + return True @@ -2321,23 +2413,89 @@ def derive_column_states_from_subquery( return column_states -def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: +def _check_expressions_for_types( + expressions: Optional[List["Expression"]], + check_data_gen: bool = False, + check_window: bool = False, + check_aggregation: bool = False, +) -> bool: + """Efficiently check if expressions contain specific types in a single pass. 
+ + Args: + expressions: List of expressions to check + check_data_gen: Check for data generator functions + check_window: Check for window functions + check_aggregation: Check for aggregation functions + + Returns: + True if any requested type is found + """ if expressions is None: return False + for exp in expressions: - if isinstance(exp, WindowExpression): + if exp is None: + continue + + # Check window functions + if check_window and isinstance(exp, WindowExpression): return True - if isinstance(exp, FunctionExpression) and ( - exp.is_data_generator - or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION + + if check_data_gen: + if isinstance(exp, FunctionExpression) and ( + exp.is_data_generator + or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION + ): + # https://docs.snowflake.com/en/sql-reference/functions-data-generation + return True + + # Check aggregation functions + if check_aggregation and isinstance(exp, FunctionExpression): + if exp.name.lower() in context._aggregation_function_set: + return True + + # Recursively check children. + # Some expression types (e.g. CaseWhen) store sub-expressions in + # _child_expressions rather than children; fall back to that. + sub_exps = exp.children + if not sub_exps and hasattr(exp, "_child_expressions"): + sub_exps = exp._child_expressions + if _check_expressions_for_types( + sub_exps, check_data_gen, check_window, check_aggregation ): - # https://docs.snowflake.com/en/sql-reference/functions-data-generation - return True - if exp is not None and has_data_generator_exp(exp.children): return True + return False +def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool: + """Check if expressions contain data generator functions. + + Note: + In non-connect mode, window expressions are also treated as data generators + for backward compatibility. + """ + if not context._is_snowpark_connect_compatible_mode: + return _check_expressions_for_types( + expressions, check_data_gen=True, check_window=True + ) + return _check_expressions_for_types(expressions, check_data_gen=True) + + +def has_data_generator_or_window_function_exp( + expressions: Optional[List["Expression"]], +) -> bool: + """Check if expressions contain data generators or window functions.""" + return _check_expressions_for_types( + expressions, check_data_gen=True, check_window=True + ) + + +def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool: + """Check if expressions contain aggregation functions.""" + return _check_expressions_for_types(expressions, check_aggregation=True) + + def has_nondeterministic_data_generation_exp( expressions: Optional[List["Expression"]], ) -> bool: diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index 2a0ad2fdf3..8329741835 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -31,6 +31,118 @@ # This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect _is_snowpark_connect_compatible_mode = False +_aggregation_function_set = ( + set() +) # lower cased names of aggregation functions, used in sql simplification +_aggregation_function_set_lock = threading.RLock() + +# Hardcoded fallback for system built-in aggregation functions. +# Used when the dynamic query fails to retrieve the list from the database. 
+# +# Generated via: +# show functions ->> select "name" from $1 where "is_aggregate" = 'Y' +# +# Entries with parentheses in the name (COUNT(*), COUNT_INTERNAL(*)) are excluded +# because FunctionExpression.name stores only the function name without parens, +# so they can never match at the lookup site. +_KNOWN_AGGREGATION_FUNCTIONS = frozenset( + [ + "accumulate", + "agg", + "ai_agg", + "ai_summarize_agg", + "any_value", + "approximate_count_distinct", + "approximate_jaccard_index", + "approximate_similarity", + "approx_count_distinct", + "approx_percentile", + "approx_percentile_accumulate", + "approx_percentile_combine", + "approx_top_k", + "approx_top_k_accumulate", + "approx_top_k_combine", + "arrayagg", + "array_agg", + "array_union_agg", + "array_unique_agg", + "avg", + "bitandagg", + "bitand_agg", + "bitmap_construct_agg", + "bitmap_or_agg", + "bitoragg", + "bitor_agg", + "bitxoragg", + "bitxor_agg", + "bit_andagg", + "bit_and_agg", + "bit_oragg", + "bit_or_agg", + "bit_xoragg", + "bit_xor_agg", + "booland_agg", + "boolor_agg", + "boolxor_agg", + "corr", + "count", + "count_if", + "count_internal", + "covar_pop", + "covar_samp", + "datasketches_hll", + "datasketches_hll_accumulate", + "datasketches_hll_combine", + "first_value", + "hash_agg", + "hll", + "hll_accumulate", + "hll_combine", + "kurtosis", + "last_value", + "listagg", + "max", + "max_by", + "median", + "min", + "minhash", + "minhash_combine", + "min_by", + "mode", + "objectagg", + "object_agg", + "percentile_cont", + "percentile_disc", + "regr_avgx", + "regr_avgy", + "regr_count", + "regr_intercept", + "regr_r2", + "regr_slope", + "regr_sxx", + "regr_sxy", + "regr_syy", + "skew", + "stddev", + "stddev_pop", + "stddev_samp", + "st_intersection_agg_geography_internal", + "st_union_agg_geography_internal", + "sum", + "sum_internal", + "sum_internal_real", + "sum_real", + "variance", + "variance_pop", + "variance_samp", + "var_pop", + "var_samp", + "vector_avg", + "vector_max", + "vector_min", + "vector_sum", + ] +) _cte_error_threshold = 3 # 0 to disable auto-cte-disable, otherwise the number of times CTE optimization can fail before it is automatically disabled for the remainder of the session. 
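Reviewer note: the fallback set above feeds the new has_aggregation_function_exp check in select_statement.py, which keeps a filter from being flattened over a projection that aliases an aggregate ('SELECT a, SUM(b) AS c FROM t WHERE c < 1' is invalid SQL because "c" is not visible to WHERE). A minimal, self-contained sketch of that membership test follows; Expr is a hypothetical stand-in, not the real FunctionExpression class, and the subset of aggregate names is illustrative only.

# Hypothetical sketch: why a filter over an aliased aggregate must not be
# flattened by the SQL simplifier.
AGGREGATES = {"sum", "avg", "min", "max", "count"}  # subset of the fallback list

class Expr:
    """Stand-in for a FunctionExpression-style node: a name plus children."""
    def __init__(self, name, children=None):
        self.name = name
        self.children = children or []

def uses_aggregate(exp):
    """Return True if this expression tree calls any known aggregation function."""
    if exp.name.lower() in AGGREGATES:
        return True
    return any(uses_aggregate(child) for child in exp.children)

# (SUM(b) + 1) AS c: 'SELECT a, SUM(b) + 1 AS c FROM t WHERE c < 1' is invalid
# SQL, so the filter has to stay in an outer query around the subquery.
projection = Expr("add", [Expr("SUM", [Expr("b")]), Expr("lit_1")])
assert uses_aggregate(projection)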
diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index 22bfbf0f15..aab601a1f9 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -798,9 +798,6 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: self._file_type = "CSV" schema_to_cast, transformations = None, None - # this parameter determine whether generate schema_to_cast and transformations when user schema exist - # schema_to_cast and transformations is needed to apply try_cast to data - user_schema_with_try_cast = False if not self._user_schema: if not self._infer_schema: @@ -837,13 +834,7 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations = [] else: self._cur_options["INFER_SCHEMA"] = False - try_cast = self._cur_options.get("TRY_CAST", False) - ( - schema, - schema_to_cast, - transformations, - ) = self._get_schema_from_csv_user_input(self._user_schema, try_cast) - user_schema_with_try_cast = try_cast + schema = self._user_schema._to_attributes() metadata_project, metadata_schema = self._get_metadata_project_and_schema() @@ -869,7 +860,6 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations=transformations, metadata_project=metadata_project, metadata_schema=metadata_schema, - use_user_schema=user_schema_with_try_cast, ), analyzer=self._session._analyzer, ), @@ -890,7 +880,6 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations=transformations, metadata_project=metadata_project, metadata_schema=metadata_schema, - use_user_schema=user_schema_with_try_cast, ), _ast_stmt=stmt, _emit_ast=_emit_ast, @@ -1399,38 +1388,6 @@ def _infer_schema_for_file_format( return new_schema, schema_to_cast, read_file_transformations, None - def _get_schema_from_csv_user_input( - self, user_schema: StructType, try_cast: bool - ) -> Tuple[List, Optional[List], Optional[List]]: - """ - This function accept a user input structtype and return schemas needed for reading CSV file. - CSV files are processed differently than semi-structured file so need a different helper function. 
- """ - schema_to_cast = [] - transformations = [] - new_schema = [] - for index, field in enumerate(user_schema.fields, start=1): - new_schema.append( - Attribute( - field.column_identifier.quoted_name, - field.datatype, - field.nullable, - ) - ) - sf_type = convert_sp_to_sf_type(field.datatype) - # TODO: SNOW-3324409 Support relaxed schema when read csv in copy mode - if try_cast: - identifier = f"TRY_CAST(${index} AS {sf_type})" - schema_to_cast.append((identifier, field.name)) - transformations.append(sql_expr(identifier)) - - read_file_transformations = [t._expression.sql for t in transformations] - # schema_to_cast and read_file_transformations should only exist when try_cast is True - # this is meant to not break the current behavior - if not try_cast: - return new_schema, None, None - return new_schema, schema_to_cast, read_file_transformations - def _get_schema_from_user_input( self, user_schema: StructType ) -> Tuple[List, List, List]: diff --git a/src/snowflake/snowpark/mock/_select_statement.py b/src/snowflake/snowpark/mock/_select_statement.py index d21a86aeda..2a149db59e 100644 --- a/src/snowflake/snowpark/mock/_select_statement.py +++ b/src/snowflake/snowpark/mock/_select_statement.py @@ -412,7 +412,7 @@ def filter(self, col: Expression) -> "MockSelectStatement": else: dependent_columns = derive_dependent_columns(col) can_be_flattened = can_clause_dependent_columns_flatten( - dependent_columns, self.column_states + dependent_columns, self.column_states, "filter" ) if can_be_flattened: new = copy(self) @@ -433,7 +433,7 @@ def sort(self, cols: List[Expression]) -> "MockSelectStatement": else: dependent_columns = derive_dependent_columns(*cols) can_be_flattened = can_clause_dependent_columns_flatten( - dependent_columns, self.column_states + dependent_columns, self.column_states, "sort" ) if can_be_flattened: new = copy(self) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index e8f967c8c0..18d8268ac4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -5034,6 +5034,53 @@ def _execute_sproc_internal( # Note the collect is implicit within the stored procedure call, so should not emit_ast here. 
return df.collect(statement_params=statement_params, _emit_ast=False)[0][0] + def _retrieve_aggregation_function_list(self) -> None: + """Retrieve the list of aggregation functions which will later be used in sql simplifier.""" + if ( + not context._is_snowpark_connect_compatible_mode + or context._aggregation_function_set + ): + return + + retrieved_set = set() + + # User-defined aggregation functions + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """select function_name from information_schema.functions where is_aggregate = 'YES'""" + ).collect() + } + ) + except BaseException as e: + _logger.debug( + "Unable to get user-defined aggregation functions: %s", + e, + ) + + # System built-in aggregation functions + try: + retrieved_set.update( + { + r[0].lower() + for r in self.sql( + """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""" + ).collect() + } + ) + except BaseException as e: + _logger.debug( + "Unable to get system aggregation functions, " + "falling back to hardcoded list: %s", + e, + ) + retrieved_set.update(context._KNOWN_AGGREGATION_FUNCTIONS) + + with context._aggregation_function_set_lock: + context._aggregation_function_set.update(retrieved_set) + def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame: """ Returns a DataFrame representing the results of a directory table query on the specified stage. diff --git a/tests/integ/scala/test_dataframe_reader_suite.py b/tests/integ/scala/test_dataframe_reader_suite.py index cc1dfd1cc9..be91633c10 100644 --- a/tests/integ/scala/test_dataframe_reader_suite.py +++ b/tests/integ/scala/test_dataframe_reader_suite.py @@ -375,46 +375,6 @@ def test_read_csv(session, mode): assert "is out of range" in str(ex_info.value) -@pytest.mark.xfail( - "config.getoption('local_testing_mode', default=False)", - reason="SNOW-1435112: csv infer schema option is not supported", - run=False, -) -@pytest.mark.parametrize("mode", ["select", "copy"]) -def test_read_csv_with_user_schema_try_cast(session, mode): - reader = get_reader(session, mode) - test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" - try_cast_schema = StructType( - [ - StructField("A", LongType()), - StructField("B", StringType()), - StructField("C", DoubleType()), - ] - ) - df_try_cast = ( - reader.schema(try_cast_schema).option("TRY_CAST", True).csv(test_file_on_stage) - ) - try_cast_res = df_try_cast.collect() - try_cast_res.sort(key=lambda x: x[0]) - assert try_cast_res == [Row(A=1, B="one", C=1.2), Row(A=2, B="two", C=2.2)] - assert df_try_cast.schema == try_cast_schema - - try_cast_schema = StructType( - [ - StructField("A", LongType()), - StructField("B", LongType()), - StructField("C", DoubleType()), - ] - ) - df_try_cast = ( - reader.schema(try_cast_schema).option("TRY_CAST", True).csv(test_file_on_stage) - ) - try_cast_res = df_try_cast.collect() - try_cast_res.sort(key=lambda x: x[0]) - assert try_cast_res == [Row(A=1, B=None, C=1.2), Row(A=2, B=None, C=2.2)] - assert df_try_cast.schema == try_cast_schema - - @pytest.mark.xfail( "config.getoption('local_testing_mode', default=False)", reason="SNOW-1435112: csv infer schema option is not supported", diff --git a/tests/integ/test_query_line_intervals.py b/tests/integ/test_query_line_intervals.py index 439f0de71e..aedb0ef030 100644 --- a/tests/integ/test_query_line_intervals.py +++ b/tests/integ/test_query_line_intervals.py @@ -57,8 +57,9 @@ def generate_test_data(session, sql_simplifier_enabled): } +@pytest.mark.parametrize("snowpark_connect_compatible_mode", 
[True, False]) @pytest.mark.parametrize( - "op,sql_simplifier,line_to_expected_sql", + "op,sql_simplifier,line_to_expected_sql,snowpark_connect_compatible_mode_sql", [ ( lambda data: data["df1"].union(data["df2"]), @@ -68,12 +69,16 @@ def generate_test_data(session, sql_simplifier_enabled): 6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', 10: 'SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (3 :: INT, \'C\' :: STRING, 300 :: INT), (4 :: INT, \'D\' :: STRING, 400 :: INT) )', }, + None, ), ( lambda data: data["df1"].filter(data["df1"].value > 150), True, { - 8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', + 8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)' + }, + { + 8: """SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM (SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, 'A' :: STRING, 100 :: INT), (2 :: INT, 'B' :: STRING, 200 :: INT)) WHERE ("VALUE" > 150)""", }, ), ( @@ -83,6 +88,7 @@ def generate_test_data(session, sql_simplifier_enabled): 1: 'SELECT "_1" AS "ID", "_2" AS "NAME" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) )', 4: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', }, + None, ), ( lambda data: data["df1"].pivot(F.col("name")).sum(F.col("value")), @@ -92,12 +98,26 @@ def generate_test_data(session, sql_simplifier_enabled): 6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)', 9: 'SELECT * FROM ( SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) ) ) PIVOT ( sum("VALUE") FOR "NAME" IN ( ANY ) )', }, + None, ), ], ) def test_get_plan_from_line_numbers_sql_content( - session, op, sql_simplifier, line_to_expected_sql + session, + op, + sql_simplifier, + line_to_expected_sql, + snowpark_connect_compatible_mode_sql, + snowpark_connect_compatible_mode, + monkeypatch, ): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + line_to_expected_sql = ( + snowpark_connect_compatible_mode_sql or line_to_expected_sql + ) session.sql_simplifier_enabled = sql_simplifier df = op(generate_test_data(session, sql_simplifier)) diff --git a/tests/integ/test_simplifier_suite.py b/tests/integ/test_simplifier_suite.py index 68804657aa..e1e4dfe2c2 100644 --- a/tests/integ/test_simplifier_suite.py +++ b/tests/integ/test_simplifier_suite.py @@ -10,7 +10,7 @@ import pytest -from snowflake.snowpark import Row +from snowflake.snowpark import Row, Window from snowflake.snowpark._internal.analyzer.select_statement import ( SET_EXCEPT, SET_INTERSECT, @@ -31,6 +31,7 @@ sum as sum_, table_function, udtf, + rank, ) from tests.utils import TestData, Utils @@ -737,7 +738,19 @@ def test_reference_non_exist_columns(session, simplifier_table): df.select(col("c") + 1).collect() -def test_order_by(setup_reduce_cast, session, 
simplifier_table):
+@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False])
+def test_order_by(
+    setup_reduce_cast,
+    session,
+    simplifier_table,
+    is_snowpark_connect_compatible_mode,
+    monkeypatch,
+):
+    if is_snowpark_connect_compatible_mode:
+        import snowflake.snowpark.context as context
+
+        monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
+
     df = session.table(simplifier_table)
 
     # flatten
@@ -755,26 +768,80 @@ def test_order_by(setup_reduce_cast, session, simplifier_table):
         f'SELECT "A", "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
     )
 
-    # no flatten because c is a new column
+    # snowpark connect compatible mode: flatten if a new column is used in the order by clause
+    # snowflake mode: no flatten because c is a new column
     df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).sort("a", "b", "c")
+    compare_sql = (
+        f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST'
+        if is_snowpark_connect_compatible_mode
+        else f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST'
+    )
     assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql(
-        f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST, "C" ASC NULLS FIRST'
+        compare_sql
     )
 
-    # no flatten because a and be are changed
+    # snowpark connect compatible mode: flatten even if a is changed because it's used in the order by clause
+    # snowflake mode: no flatten because a and b are changed
     df4 = df.select((col("a") + 1).as_("a"), ((col("b") + 1).as_("b"))).sort("a", "b")
+    compare_sql = (
+        f'SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
+        if is_snowpark_connect_compatible_mode
+        else f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
+    )
     assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql(
-        f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST'
+        compare_sql
     )
 
-    # subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten. 
- df5 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b") + # snowpark connect compatible mode: flatten if a window function is used in the projection + # snowflake mode: no flatten because c is a new column + df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).sort( + "a", "b" + ) + compare_sql = ( + f'SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql + ) + + # No flatten if a data generator is used in the projection + df6 = df.select("a", "b", seq1().alias("c")).sort("a", "b") + assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table}) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' + ) + + # subquery has sql text so unable to figure out if a data generator is used in the projection. No flatten. + df7 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).sort("a", "b") + assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) ORDER BY "A" ASC NULLS FIRST, "B" ASC NULLS FIRST' ) + df8 = df.select("a", sum_(col("b")).alias("c")).sort("c") + compare_sql = ( + f'SELECT "A", sum("B") AS "C" FROM {simplifier_table} ORDER BY "C" ASC NULLS FIRST' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", sum("B") AS "C" FROM {simplifier_table} ) ORDER BY "C" ASC NULLS FIRST' + ) + assert Utils.normalize_sql(df8.queries["queries"][-1]) == Utils.normalize_sql( + compare_sql + ) + + +@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False]) +def test_filter( + setup_reduce_cast, + session, + simplifier_table, + is_snowpark_connect_compatible_mode, + monkeypatch, +): + if is_snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) -def test_filter(setup_reduce_cast, session, simplifier_table): df = session.table(simplifier_table) integer_literal_postfix = ( "" if session.eliminate_numeric_sql_value_cast_enabled else " :: INT" @@ -792,35 +859,62 @@ def test_filter(setup_reduce_cast, session, simplifier_table): f'SELECT "A", "B" FROM {simplifier_table} WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) - # no flatten because c is a new column + # flatten if a regular new column is in the projection df3 = df.select("a", "b", (col("a") - col("b")).as_("c")).filter( (col("a") > 1) & (col("b") > 2) & (col("c") < 1) ) + compare_sql = ( + f'SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + if is_snowpark_connect_compatible_mode + else f'SELECT * FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) assert Utils.normalize_sql(df3.queries["queries"][-1]) == Utils.normalize_sql( - f'SELECT * 
FROM ( SELECT "A", "B", ("A" - "B") AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + compare_sql + ) + + # no flatten if a window function is used in the projection + df5 = df.select("a", "b", rank().over(Window.order_by("b")).alias("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", rank() OVER (ORDER BY "B" ASC NULLS FIRST) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' + ) + + # no flatten if a data generator is used in the projection + df6 = df.select("a", "b", seq1().alias("c")).filter( + (col("a") > 1) & (col("b") > 2) & (col("c") < 1) + ) + assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql( + f'SELECT * FROM ( SELECT "A", "B", seq1(0) AS "C" FROM {simplifier_table} ) WHERE ((("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix})) AND ("C" < 1{integer_literal_postfix}))' ) # no flatten because a and be are changed - df4 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter( + df7 = df.select((col("a") + 1).as_("a"), (col("b") + 1).as_("b")).filter( (col("a") > 1) & (col("b") > 2) ) - assert Utils.normalize_sql(df4.queries["queries"][-1]) == Utils.normalize_sql( + assert Utils.normalize_sql(df7.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT * FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) - df5 = df4.select("a") - assert Utils.normalize_sql(df5.queries["queries"][-1]) == Utils.normalize_sql( + df8 = df7.select("a") + assert Utils.normalize_sql(df8.queries["queries"][-1]) == Utils.normalize_sql( f'SELECT "A" FROM ( SELECT ("A" + 1{integer_literal_postfix}) AS "A", ("B" + 1{integer_literal_postfix}) AS "B" FROM {simplifier_table} ) WHERE (("A" > 1{integer_literal_postfix}) AND ("B" > 2{integer_literal_postfix}))' ) # subquery has sql text so unable to figure out same-level dependency, so assuming d depends on c. No flatten. 
-    df6 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter(
+    df9 = df.select("a", "b", lit(3).as_("c"), sql_expr("1 + 1 as d")).filter(
         col("a") > 1
     )
-    assert Utils.normalize_sql(df6.queries["queries"][-1]) == Utils.normalize_sql(
+    assert Utils.normalize_sql(df9.queries["queries"][-1]) == Utils.normalize_sql(
         f'SELECT * FROM ( SELECT "A", "B", 3 :: INT AS "C", 1 + 1 as d FROM ( SELECT * FROM {simplifier_table} ) ) WHERE ("A" > 1{integer_literal_postfix})'
     )
 
+    # no flatten if an aggregation function is used in the projection
+    df10 = df.select("a", sum_(col("b")).alias("c")).filter(col("c") < 1)
+    assert Utils.normalize_sql(df10.queries["queries"][-1]) == Utils.normalize_sql(
+        f'SELECT * FROM ( SELECT "A", sum("B") AS "C" FROM {simplifier_table} ) WHERE ("C" < 1{integer_literal_postfix})'
+    )
+
 
 def test_limit(setup_reduce_cast, session, simplifier_table):
     df = session.table(simplifier_table)
@@ -1620,18 +1714,21 @@ def test_chained_sort(session):
     )
 
 
+@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False])
 @pytest.mark.parametrize(
-    "operation,simplified_query",
+    "operation,simplified_query,snowpark_connect_simplified_query",
     [
         # Flattened
         (
            lambda df: df.filter(col("A") > 1).select(col("B") + 1),
            'SELECT ("B" + 1{POSTFIX}) FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX})',
+            None,
        ),
        # Flattened, if there are duplicate column names across the parent/child, WHERE is evaluated on subquery first, so we could flatten in this case
        (
            lambda df: df.filter(col("A") > 1).select((col("B") + 1).alias("A")),
            'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX})',
+            None,
        ),
        # Flattened
        (
@@ -1639,6 +1736,27 @@ def test_chained_sort(session):
            lambda df: df.filter(col("A") > 1)
            .select(col("A"), col("B"), lit(12).alias("TWELVE"))
            .filter(col("A") > 2),
            'SELECT "A", "B", 12 :: INT AS "TWELVE" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE (("A" > 1{POSTFIX}) AND ("A" > 2{POSTFIX}))',
+            None,
+        ),
+        # Flattened if the dropped columns are not used in the filter in snowpark connect compatible mode
+        # Notice the local inner flattening of the WHERE clauses: in snowpark connect compatible mode, the NEW column "D" can be flattened into the new query
+        (
+            lambda df: df.filter(col("A") >= 1)
+            .select(col("A").alias("C"), col("B").alias("D"))
+            .filter(col("C") > 2)
+            .select(col("C")),
+            'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE ("C" > 2{POSTFIX})',
+            'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("C" > 2{POSTFIX}))',
+        ),
+        # Flattened if the dropped columns are not in the filter clause's dependent columns in snowpark connect compatible mode
+        # Notice the local inner flattening of the WHERE clauses: in snowpark connect compatible mode, the NEW column "D" can be flattened into the new query
+        (
+            lambda df: df.filter(col("A") >= 1)
+            .select(col("A").alias("C"), col("B").alias("D"))
+            .filter((col("C") + 1) > 2)
+            .select(col("C")),
+            'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE (("C" + 1{POSTFIX}) > 2{POSTFIX})',
+            'SELECT 
"A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("C" + 1{POSTFIX}) > 2{POSTFIX}))', ), # Not fully flattened, since col("A") > 1 and col("A") > 2 are referring to different columns ( @@ -1646,25 +1764,71 @@ def test_chained_sort(session): .select((col("B") + 1).alias("A")) .filter(col("A") > 2), 'SELECT * FROM ( SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX}) ) WHERE ("A" > 2{POSTFIX})', + None, ), # Not flattened, since A is updated in the select after filter. ( lambda df: df.filter(col("A") > 1).select("A", seq1(0)), 'SELECT "A", seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE ("A" > 1{POSTFIX}) )', + None, ), # Not flattened, since we cannot detect dependent columns from sql_expr ( lambda df: df.filter(sql_expr("A > 1")).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) WHERE A > 1 )', + None, ), # Not flattened, since we cannot flatten when the subquery uses positional parameter ($1) ( lambda df: df.filter(col("$1") > 1).select(col("B"), col("A")), 'SELECT "B", "A" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) WHERE ("$1" > 1{POSTFIX}) )', + None, + ), + # Not flattened if a dropped column is used in the filter clause + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A"), col("B").alias("D")) + .filter(col("D") > -3) + .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE ("D" > -3{POSTFIX})', + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND ("D" > -3{POSTFIX})))', + ), + # Not flattened if a dropped column is used in the select clause's dependent columns + # Notice the local inner flattening happening to WHERE clauses because in snowpark connect compatible mode, NEW column "D" can be flattened into the new query + ( + lambda df: df.filter(col("A") >= 1) + .select(col("A"), col("B").alias("D")) + .filter((col("D") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE ("A" >= 1{POSTFIX})) WHERE (("D" - 1{POSTFIX}) > -4{POSTFIX})', + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) WHERE (("A" >= 1{POSTFIX}) AND (("D" - 1{POSTFIX}) > -4{POSTFIX})))', + ), + # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns + ( + lambda df: df.select(col("A"), (col("B") + 1).alias("B")) + .filter((col("B") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: 
INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + None, ), ], ) -def test_select_after_filter(setup_reduce_cast, session, operation, simplified_query): +def test_select_after_filter( + setup_reduce_cast, + session, + operation, + simplified_query, + snowpark_connect_compatible_mode, + monkeypatch, + snowpark_connect_simplified_query, +): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + simplified_query = snowpark_connect_simplified_query or simplified_query + session.sql_simplifier_enabled = False df1 = session.create_dataframe([[1, -2], [3, -4]], schema=["a", "b"]) @@ -1684,43 +1848,50 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q ) == Utils.normalize_sql(simplified_query) +@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False]) @pytest.mark.parametrize( - "operation,simplified_query,execute_sql", + "operation,simplified_query,snowpark_connect_simplified_query,execute_sql", [ # Flattened ( lambda df: df.order_by(col("A")).select(col("B") + 1), 'SELECT ("B" + 1{POSTFIX}) FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST', + None, True, ), # Not flattened because SEQ1() is a data generator. ( lambda df: df.order_by(col("A")).select(seq1(0)), 'SELECT seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', + 'SELECT seq1(0) FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST ) ORDER BY "A" ASC NULLS FIRST', True, ), # Not flattened, unlike filter, current query takes precendence when there are duplicate column names from a ORDERBY clause ( lambda df: df.order_by(col("A")).select((col("B") + 1).alias("A")), 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST )', + 'SELECT ("B" + 1{POSTFIX}) AS "A" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "A" ASC NULLS FIRST ) ORDER BY "A" ASC NULLS FIRST', True, ), # Not flattened, since we cannot detect dependent columns from sql_expr ( lambda df: df.order_by(sql_expr("A")).select(col("B")), 'SELECT "B" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY A ASC NULLS FIRST )', + None, True, ), # Not flattened, since we cannot flatten when the subquery uses positional parameter ($1) ( lambda df: df.order_by(col("$1")).select(col("B")), 'SELECT "B" FROM ( SELECT * FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ) ORDER BY "$1" ASC NULLS FIRST )', + None, True, ), # Not flattened, skip execution since this would result in SnowparkSQLException ( lambda df: df.order_by(col("C")).select((col("A") + col("B")).alias("C")), 'SELECT ("A" + "B") AS "C" FROM ( SELECT "A", "B" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "C" ASC NULLS FIRST )', + None, False, ), # Flattened @@ -1730,13 +1901,92 @@ def test_select_after_filter(setup_reduce_cast, session, operation, simplified_q 
.order_by(col("B")) .select(col("A")), 'SELECT "A" FROM ( SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT) ) ORDER BY "B" ASC NULLS FIRST, "A" ASC NULLS FIRST', + None, + True, + ), + # Flattened if the dropped columns are not used in filter in the snowpark connect compatible mode + ( + lambda df: df.select(col("A").alias("C"), col("B").alias("D")) + .order_by(col("C")) + .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "C" ASC NULLS FIRST', + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "C" ASC NULLS FIRST', + True, + ), + # Flattened if the dropped columns are not in the order by clause's dependent columns + ( + lambda df: df.select(col("A").alias("C"), col("B").alias("D")) + .order_by(col("C") + 1) + .select(col("C")), + 'SELECT "C" FROM (SELECT "A" AS "C", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', + 'SELECT "A" AS "C" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("C" + 1{POSTFIX}) ASC NULLS FIRST', + True, + ), + # Not flattened if a dropped new column is used in the order by clause + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D")) + .select(col("A").alias("E")), + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS FIRST', + 'SELECT "A" AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST) ORDER BY "D" ASC NULLS FIRST', + True, + ), + # Not flattened if a dropped new column is used in the order by clause's dependent columns + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D") - 1) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST', + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST) ORDER BY ("D" - 1{POSTFIX}) ASC NULLS FIRST', + True, + ), + # Not flattened if a dropped new column is used in the order by clause, select without alias + ( + lambda df: df.select(col("A"), col("B").alias("D")) + .order_by(col("D")) + .select(col("A")), + 'SELECT "A" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS FIRST', + 'SELECT "A" FROM (SELECT "A", "B" AS "D" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST) ORDER BY "D" ASC NULLS FIRST', + True, + ), + # Not flattened with multiple order by columns and partial drop + ( + lambda df: df.select( + col("A"), col("B").alias("D"), (col("A") + col("B")).alias("E") + ) + .order_by(col("D"), col("E")) + .select(col("A")), + 'SELECT "A" FROM (SELECT "A", "B" AS "D", ("A" + "B") AS "E" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) ORDER BY "D" ASC NULLS 
FIRST, "E" ASC NULLS FIRST', + 'SELECT "A" FROM (SELECT "A", "B" AS "D", ("A" + "B") AS "E" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT)) ORDER BY "D" ASC NULLS FIRST, "E" ASC NULLS FIRST) ORDER BY "D" ASC NULLS FIRST, "E" ASC NULLS FIRST', + True, + ), + # Not flattened if a dropped column that was changed expression is used in the select clause's dependent columns + ( + lambda df: df.select(col("A"), (col("B") + 1).alias("B")) + .filter((col("B") - 1) > -4) + .select((col("A") + 1).alias("E")), + 'SELECT ("A" + 1{POSTFIX}) AS "E" FROM (SELECT "A", ("B" + 1{POSTFIX}) AS "B" FROM (SELECT $1 AS "A", $2 AS "B" FROM VALUES (1 :: INT, -2 :: INT), (3 :: INT, -4 :: INT))) WHERE (("B" - 1{POSTFIX}) > -4{POSTFIX})', + None, True, ), ], ) def test_select_after_orderby( - setup_reduce_cast, session, operation, simplified_query, execute_sql + setup_reduce_cast, + session, + operation, + simplified_query, + execute_sql, + snowpark_connect_compatible_mode, + monkeypatch, + snowpark_connect_simplified_query, ): + if snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + simplified_query = snowpark_connect_simplified_query or simplified_query + session.sql_simplifier_enabled = False df1 = session.create_dataframe([[1, -2], [3, -4]], schema=["a", "b"]) @@ -1756,6 +2006,130 @@ def test_select_after_orderby( Utils.check_answer(operation(df1), operation(df2)) +@pytest.mark.parametrize("is_snowpark_connect_compatible_mode", [True, False]) +def test_order_by_preserved_after_non_flattened_select( + session, monkeypatch, is_snowpark_connect_compatible_mode +): + """When select() can't be flattened and drops columns used in order_by, + compat mode should preserve the order_by on the outer SelectStatement + so that subsequent operations like limit() maintain correct ordering. 
+ In both modes, data ordering should be correct.""" + if is_snowpark_connect_compatible_mode: + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + + df = session.create_dataframe( + [[1, 3], [2, 1], [3, 2]], + schema=["a", "b"], + ) + + # order_by uses a column (D) that gets dropped by the subsequent select + df1 = df.select(col("A"), col("B").alias("D")).order_by(col("D")).select(col("A")) + + assert "ORDER BY" in df1.queries["queries"][-1] + # B values [3,1,2] → sorted D [1,2,3] → A order is [2,3,1] + Utils.check_answer(df1, [Row(2), Row(3), Row(1)], sort=False) + + if is_snowpark_connect_compatible_mode: + # In compat mode, the order_by should be preserved on the outer statement + assert df1._select_statement.order_by is not None + + # Verify ordering survives through limit() + limited = df1.limit(2) + Utils.check_answer(limited, [Row(2), Row(3)], sort=False) + + # Verify with multiple order_by columns where some are dropped + df2 = ( + df.select(col("A"), col("B").alias("D"), (col("A") + col("B")).alias("E")) + .order_by(col("D"), col("E")) + .select(col("A")) + ) + Utils.check_answer(df2, [Row(2), Row(3), Row(1)], sort=False) + + # Verify the chained select().orderBy().select() pattern from the bug report: + # columns used in orderBy are dropped by the outer select + df3 = ( + df.select(col("A"), col("A").alias("S"), col("B").alias("SK")) + .order_by(col("A"), col("SK")) + .select("A", "S") + ) + assert "ORDER BY" in df3.queries["queries"][-1] + Utils.check_answer(df3, [Row(1, 1), Row(2, 2), Row(3, 3)], sort=False) + + # Verify order survives through limit on the chained pattern + limited3 = df3.limit(2) + Utils.check_answer(limited3, [Row(1, 1), Row(2, 2)], sort=False) + + +def test_order_by_outer_hoist_skipped_when_unsafe_compat_mode(session, monkeypatch): + """Snowpark Connect compat mode: non-flattened ``select()`` may clear outer ``order_by`` + when ORDER BY deps are ALL/DOLLAR, or missing from both ``from_.column_states`` and + ``self.column_states`` (non-flatten path only). + + When the plan **flattens** (e.g. ``order_by(col("Z")).select(col("A"))`` with no ``Z`` + in the schema), ORDER BY is merged into a single ``SELECT ... ORDER BY "Z"`` and + ``_select_statement.order_by`` stays populated. + + To hit the non-flattened path where ``order_by`` deps are missing from both + ``from_.column_states`` and ``self.column_states`` (lines 1498-1499), use + ``order_by(col("Z")).select(seq1(0))``: ``seq1`` forces non-flatten, so outer + ``order_by`` is cleared in compat mode. + + If ORDER BY references CHANGED_EXP/DROPPED keys present in ``self.column_states``, + the inner projection may be augmented before wrapping; outer ``order_by`` is still + set so limit() and similar keep a deterministic sort (see select_statement.select). 
+ """ + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + + df = session.create_dataframe( + [[1, 3], [2, 1], [3, 2]], + schema=["A", "B"], + ) + + # sql_expr ORDER BY: ALL/unknown deps -> do not set outer order_by + df_sql = df.order_by(sql_expr("A")).select(col("B")) + assert df_sql._select_statement.order_by is None + assert "ORDER BY" in df_sql.queries["queries"][-1] + + # Positional $n ORDER BY: DOLLAR -> do not set outer order_by + df_dollar = df.order_by(col("$1")).select(col("B")) + assert df_dollar._select_statement.order_by is None + + # ORDER BY unknown column "Z" then select: typically flattens to one SELECT with + # ORDER BY "Z" (does not use the non-flatten compat outer order_by=None branch). + df_order_by_z = df.order_by(col("Z")).select(col("A")) + assert df_order_by_z._select_statement.order_by is not None + normalized_z = Utils.normalize_sql(df_order_by_z.queries["queries"][-1]) + assert "ORDER BY" in normalized_z + assert '"Z"' in normalized_z + + # Non-flatten: seq1() in select blocks flattening; ORDER BY "Z" with Z not in either + # column_states -> compat mode clears outer order_by (select_statement lines 1498-1499). + df_order_by_z_seq = df.order_by(col("Z")).select(seq1(0)) + assert df_order_by_z_seq._select_statement.order_by is None + assert "ORDER BY" in Utils.normalize_sql(df_order_by_z_seq.queries["queries"][-1]) + + # ORDER BY on CHANGED_EXP column: augment inner projection; outer still has order_by + df_changed = ( + df.select((col("A") * 2).alias("A"), col("B")) + .order_by(col("A")) + .select(col("B")) + ) + assert df_changed._select_statement.order_by is not None + # Inner sort by computed A (2, 4, 6) -> B order 3, 1, 2 + Utils.check_answer(df_changed, [Row(3), Row(1), Row(2)], sort=False) + + # ORDER BY on alias only in this select (not in from_.column_states): still in + # self.column_states, so outer order_by is set (not the "missing from both" case) + df_hoist = ( + df.select(col("A"), col("B").alias("D")).order_by(col("D")).select(col("A")) + ) + assert df_hoist._select_statement.order_by is not None + + def test_window_with_filter(session): df = session.create_dataframe([[0], [1]], schema=["A"]) df = ( @@ -1950,3 +2324,19 @@ def test_select_distinct( ) finally: session.conf.set("use_simplified_query_generation", original) + + +def test_retrieving_aggregation_funcs(session, monkeypatch): + import snowflake.snowpark.context as context + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + assert not context._aggregation_function_set + session._retrieve_aggregation_function_list() + assert context._aggregation_function_set + + monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", False) + monkeypatch.setattr(context, "_aggregation_function_set", set()) + assert not context._aggregation_function_set + session._retrieve_aggregation_function_list() + assert not context._aggregation_function_set
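Reviewer note: for anyone validating this end to end, the sketch below reproduces the select().order_by().select() pattern these tests cover, where the sort key is dropped by the outer select. It is a hedged illustration, not part of the patch: the connection parameters are placeholders, a reachable Snowflake account is assumed, and only APIs already exercised in this diff (create_dataframe, alias, order_by, limit, DataFrame.queries) are used.

from snowflake.snowpark import Session
from snowflake.snowpark.functions import col

# Placeholder credentials; swap in real connection parameters.
connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
}
session = Session.builder.configs(connection_parameters).create()

df = session.create_dataframe([[1, 3], [2, 1], [3, 2]], schema=["a", "b"])

# "d" is the sort key and is then dropped by the outer select. In snowpark
# connect compatible mode the ORDER BY is preserved on the outer statement,
# so limit(2) returns the rows with the two smallest "d" values: a=2, then a=3.
result = (
    df.select(col("a"), col("b").alias("d"))
    .order_by(col("d"))
    .select(col("a"))
    .limit(2)
)
print(result.queries["queries"][-1])  # the generated SQL keeps an ORDER BY "D"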