Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b694e2f
Loosen flattening rules for sort and filter
sfc-gh-yixie Oct 23, 2025
a25942b
No flatten if dropped columns after order by / sort
sfc-gh-yixie Oct 27, 2025
1df1e96
Add tests in simplifier
sfc-gh-yixie Oct 28, 2025
1104361
Update test after flattening
sfc-gh-yixie Oct 28, 2025
d67f5df
parameter protection and agg function check for filter
sfc-gh-aling Oct 30, 2025
43c3316
fix tests
sfc-gh-aling Oct 31, 2025
3f7a98e
fix local testing
sfc-gh-aling Oct 31, 2025
8035696
update
sfc-gh-aling Oct 31, 2025
7cb8af7
Merge branch 'main' into yixie-SNOW-2203826-flatten-filter-sort-new
sfc-gh-yixie Dec 6, 2025
c00e140
Merge branch 'main' into yixie-SNOW-2203826-flatten-filter-sort-new
sfc-gh-yixie Dec 9, 2025
4c75cb2
Merge remote-tracking branch 'origin/main' into yixie-SNOW-2203826-fl…
sfc-gh-mayliu Mar 26, 2026
7de83f4
Merge branch 'yixie-SNOW-2203826-flatten-filter-sort-new' of github.c…
sfc-gh-yixie Mar 30, 2026
660a5a3
Re-sort for df.sort().select() when possible for Snowpark Connect
sfc-gh-yixie Apr 7, 2026
3b9fb2c
Merge branch 'main' of github.com:snowflakedb/snowpark-python into yi…
sfc-gh-yixie Apr 7, 2026
70cabab
simplify code
sfc-gh-aling Apr 7, 2026
7df91d6
Fix select->sort->filter where sort has more columns than select
sfc-gh-yixie Apr 13, 2026
7c436a1
Merge branch 'yixie-SNOW-2203826-flatten-filter-sort-new' of github.c…
sfc-gh-yixie Apr 13, 2026
4e594ec
resolve conflict
sfc-gh-aling Apr 13, 2026
7df994e
improve agg function retrieval
sfc-gh-aling Apr 13, 2026
7ba0dc0
Revert "SNOW-3266242: Support TRY_CAST with user-provided schema in D…
sfc-gh-aling Apr 14, 2026
82a556c
Avoid error when select from
sfc-gh-yixie Apr 14, 2026
b0c702b
fix workload
sfc-gh-mayliu Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

- Allow user input schema when reading Parquet file on stage.

#### Bug Fixes

- Fixed a bug that `TRY_CAST` reader option is ignored when calling `DataFrameReader.schema().csv()`.

#### Improvements

- Restored the following query improvements that were reverted in 1.47.0 due to bugs:
Expand Down
220 changes: 189 additions & 31 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Sequence,
Set,
Union,
Literal,
)

import snowflake.snowpark._internal.utils
Expand Down Expand Up @@ -86,6 +87,7 @@
is_sql_select_statement,
ExprAliasUpdateDict,
)
import snowflake.snowpark.context as context

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
# Python 3.9 can use both
Expand Down Expand Up @@ -1412,9 +1414,9 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
):
# TODO: Clean up, this entire if case is parameter protection
can_be_flattened = False
elif (self.where or self.order_by or self.limit_) and has_data_generator_exp(
cols
):
elif (
self.where or self.order_by or self.limit_
) and has_data_generator_or_window_function_exp(cols):
can_be_flattened = False
elif self.where and (
(subquery_dependent_columns := derive_dependent_columns(self.where))
Expand All @@ -1425,6 +1427,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
subquery_dependent_columns & new_column_states.active_columns
)
)
or (
# unflattenable condition: dropped column is used in subquery WHERE clause and dropped column status is NEW or CHANGED in the subquery
# reason: we should not flatten because the dropped column is not available in the new query, leading to WHERE clause error
# sample query: 'select "b" from (select "a" as "c", "b" from table where "c" > 1)' can not be flatten to 'select "b" from table where "c" > 1'
context._is_snowpark_connect_compatible_mode
and new_column_states.dropped_columns
and any(
self.column_states[_col].change_state
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
for _col in (
subquery_dependent_columns & new_column_states.dropped_columns
)
)
)
):
can_be_flattened = False
elif self.order_by and (
Expand All @@ -1437,6 +1453,20 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
subquery_dependent_columns & new_column_states.active_columns
)
)
or (
# unflattenable condition: dropped column is used in subquery ORDER BY clause and dropped column status is NEW or CHANGED in the subquery
# reason: we should not flatten because the dropped column is not available in the new query, leading to ORDER BY clause error
# sample query: 'select "b" from (select "a" as "c", "b" order by "c")' can not be flatten to 'select "b" from table order by "c"'
context._is_snowpark_connect_compatible_mode
and new_column_states.dropped_columns
and any(
self.column_states[_col].change_state
in (ColumnChangeState.NEW, ColumnChangeState.CHANGED_EXP)
for _col in (
subquery_dependent_columns & new_column_states.dropped_columns
)
)
)
):
can_be_flattened = False
elif self.distinct_:
Expand Down Expand Up @@ -1482,8 +1512,54 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
self.df_ast_ids.copy() if self.df_ast_ids is not None else None
)
else:
new_order_by = None
new_from = self
if context._is_snowpark_connect_compatible_mode and self.order_by:
order_by_dependent_columns = derive_dependent_columns(*self.order_by)
if order_by_dependent_columns in (
COLUMN_DEPENDENCY_DOLLAR,
COLUMN_DEPENDENCY_ALL,
):
new_order_by = None
elif any(
col not in self.from_.column_states
and col not in self.column_states
for col in order_by_dependent_columns
):
new_order_by = None
elif any(
_col not in self.column_states
or self.column_states[_col].change_state
in (ColumnChangeState.CHANGED_EXP, ColumnChangeState.DROPPED)
for _col in order_by_dependent_columns
):
new_from = copy(self)
missing_columns = (
order_by_dependent_columns
- new_from.column_states.active_columns
)
new_from.projection = new_from.projection + [
Attribute(col, DataType()) for col in missing_columns
]
new_col_states = derive_column_states_from_subquery(
new_from.projection, new_from.from_
)
if new_col_states is not None:
new_from.column_states = new_col_states
new_from._projection_in_str = None
new_from._commented_sql = None
new_from._sql_query = None
new_order_by = self.order_by
else:
new_from = self
new_order_by = None
else:
new_order_by = self.order_by
new = SelectStatement(
projection=cols, from_=self.to_subqueryable(), analyzer=self.analyzer
projection=cols,
from_=new_from.to_subqueryable(),
order_by=new_order_by,
analyzer=self.analyzer,
)
new._merge_projection_complexity_with_subquery = (
can_select_projection_complexity_be_merged(
Expand All @@ -1504,12 +1580,17 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
return new

def filter(self, col: Expression) -> "SelectStatement":
self._session._retrieve_aggregation_function_list()
can_be_flattened = (
(not self.flatten_disabled)
and can_clause_dependent_columns_flatten(
derive_dependent_columns(col), self.column_states
derive_dependent_columns(col), self.column_states, "filter"
)
and not has_data_generator_exp(self.projection)
and not has_data_generator_or_window_function_exp(self.projection)
and not (
context._is_snowpark_connect_compatible_mode
and has_aggregation_function_exp(self.projection)
) # sum(col) as new_col, new_col can not be flattened in where clause
and not (self.order_by and self.limit_ is not None)
)
if can_be_flattened:
Expand All @@ -1535,18 +1616,14 @@ def filter(self, col: Expression) -> "SelectStatement":
def sort(self, cols: List[Expression]) -> "SelectStatement":
can_be_flattened = (
(not self.flatten_disabled)
# limit order by and order by limit can cause big performance
# difference, because limit can stop table scanning whenever the
# number of record is satisfied.
# Therefore, disallow sql simplification when the
# current SelectStatement has a limit clause to avoid moving
# order by in front of limit.
# Disallow flattening when the current SelectStatement has a
# limit clause to avoid moving order by in front of limit.
and (not self.limit_)
and (not self.offset)
and can_clause_dependent_columns_flatten(
derive_dependent_columns(*cols), self.column_states
derive_dependent_columns(*cols), self.column_states, "sort"
)
and not has_data_generator_exp(self.projection)
and not has_data_generator_or_window_function_exp(self.projection)
)
if can_be_flattened:
new = copy(self)
Expand Down Expand Up @@ -1583,7 +1660,7 @@ def distinct(self) -> "SelectStatement":
# .order_by(col1).select(col2).distinct() cannot be flattened because
# SELECT DISTINCT B FROM TABLE ORDER BY A is not valid SQL
and (not (self.order_by and self.has_projection))
and not has_data_generator_exp(self.projection)
and not has_data_generator_or_window_function_exp(self.projection)
)
if can_be_flattened:
new = copy(self)
Expand Down Expand Up @@ -2074,7 +2151,12 @@ def can_projection_dependent_columns_be_flattened(
def can_clause_dependent_columns_flatten(
dependent_columns: Optional[AbstractSet[str]],
subquery_column_states: ColumnStateDict,
clause: Literal["filter", "sort"],
) -> bool:
assert clause in (
"filter",
"sort",
), f"Invalid clause called in can_clause_dependent_columns_flatten: {clause}"
if dependent_columns == COLUMN_DEPENDENCY_DOLLAR:
return False
elif (
Expand All @@ -2089,15 +2171,25 @@ def can_clause_dependent_columns_flatten(
dc_state = subquery_column_states.get(dc)
if dc_state:
if dc_state.change_state == ColumnChangeState.CHANGED_EXP:
return False
if clause == "filter":
return False
# sort + CHANGED_EXP: safe in SCOS mode since ORDER BY
# is evaluated after projection. Keep checking remaining
# columns though — another column may be unsafe.
elif not context._is_snowpark_connect_compatible_mode:
return False
elif dc_state.change_state == ColumnChangeState.NEW:
# Most of the time this can be flattened. But if a new column uses window function and this column
# is used in a clause, the sql doesn't work in Snowflake.
# For instance `select a, rank() over(order by b) as d from test_table where d = 1` doesn't work.
# But `select a, b as d from test_table where d = 1` works
# We can inspect whether the referenced new column uses window function. Here we are being
# conservative for now to not flatten the SQL.
return False
if clause == "sort" and dc_state.dependent_columns in (
COLUMN_DEPENDENCY_DOLLAR,
COLUMN_DEPENDENCY_ALL,
):
# Scalar subqueries in sort can trigger Snowflake
# internal errors when ORDER BY references them
# at the same SELECT level.
return False
if not context._is_snowpark_connect_compatible_mode:
return False

return True


Expand Down Expand Up @@ -2321,23 +2413,89 @@ def derive_column_states_from_subquery(
return column_states


def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
def _check_expressions_for_types(
    expressions: Optional[List["Expression"]],
    check_data_gen: bool = False,
    check_window: bool = False,
    check_aggregation: bool = False,
) -> bool:
    """Efficiently check if expressions contain specific types in a single pass.

    Recursively walks ``expressions`` and their children and returns as soon
    as any requested expression kind is found.

    Args:
        expressions: List of expressions to check. ``None`` (and ``None``
            entries within the list) are treated as "nothing to check".
        check_data_gen: Check for data generator functions
            (``FunctionExpression.is_data_generator`` or names listed in
            ``SEQUENCE_DEPENDENT_DATA_GENERATION``).
        check_window: Check for window functions (``WindowExpression``).
        check_aggregation: Check for aggregation functions (function names
            present in ``context._aggregation_function_set``).

    Returns:
        True if any requested type is found in the expressions or any of
        their (transitive) children; False otherwise.
    """
    if expressions is None:
        return False

    for exp in expressions:
        if exp is None:
            continue

        # Check window functions
        if check_window and isinstance(exp, WindowExpression):
            return True

        if check_data_gen:
            if isinstance(exp, FunctionExpression) and (
                exp.is_data_generator
                or exp.name.lower() in SEQUENCE_DEPENDENT_DATA_GENERATION
            ):
                # https://docs.snowflake.com/en/sql-reference/functions-data-generation
                return True

        # Check aggregation functions
        if check_aggregation and isinstance(exp, FunctionExpression):
            if exp.name.lower() in context._aggregation_function_set:
                return True

        # Recursively check children.
        # Some expression types (e.g. CaseWhen) store sub-expressions in
        # _child_expressions rather than children; fall back to that.
        sub_exps = exp.children
        if not sub_exps and hasattr(exp, "_child_expressions"):
            sub_exps = exp._child_expressions
        if _check_expressions_for_types(
            sub_exps, check_data_gen, check_window, check_aggregation
        ):
            return True

    return False


def has_data_generator_exp(expressions: Optional[List["Expression"]]) -> bool:
    """Return True if *expressions* contain a data generator function.

    Note:
        Outside Snowpark Connect compatible mode, window expressions are
        also treated as data generators for backward compatibility.
    """
    # In non-connect mode a window function counts as a data generator too.
    include_window = not context._is_snowpark_connect_compatible_mode
    return _check_expressions_for_types(
        expressions, check_data_gen=True, check_window=include_window
    )


def has_data_generator_or_window_function_exp(
    expressions: Optional[List["Expression"]],
) -> bool:
    """Return True if *expressions* contain a data generator or a window function."""
    found = _check_expressions_for_types(
        expressions,
        check_window=True,
        check_data_gen=True,
    )
    return found


def has_aggregation_function_exp(expressions: Optional[List["Expression"]]) -> bool:
    """Return True if *expressions* contain an aggregation function call."""
    contains_agg = _check_expressions_for_types(expressions, check_aggregation=True)
    return contains_agg


def has_nondeterministic_data_generation_exp(
expressions: Optional[List["Expression"]],
) -> bool:
Expand Down
Loading
Loading