Skip to content
Draft
118 changes: 93 additions & 25 deletions src/snowflake/snowpark/_internal/analyzer/metadata_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
from collections import Counter
from enum import Enum
from dataclasses import dataclass
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union
from snowflake.snowpark.types import DataType
from snowflake.snowpark.types import ArrayType, DataType, MapType, StructType

from snowflake.snowpark._internal.analyzer.expression import (
Attribute,
Expand All @@ -13,7 +14,7 @@
Star,
)
from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql
from snowflake.snowpark._internal.analyzer.unary_expression import Alias
from snowflake.snowpark._internal.analyzer.unary_expression import Alias, Cast
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
Limit,
LogicalPlan,
Expand Down Expand Up @@ -86,39 +87,111 @@ def infer_quoted_identifiers_from_expressions(

def _extract_inferable_attribute_names(
attributes: Optional[List[Expression]],
from_attributes: Optional[List[Attribute]] = None,
) -> tuple[Optional[List[Attribute]], Optional[List[Attribute]]]:
"""
Returns a list of attribute names that can be infered from a list of Expressions.
Returns None if one or more attributes cannot be infered.
Returns a tuple of (expected_attributes, resolved_attributes) that can be
inferred from a list of Expressions.
- expected_attributes: Attributes that were already resolved (direct column refs)
- resolved_attributes: All attributes in projection order, with types resolved
Returns (None, None) if one or more attributes cannot be inferred.
"""
if attributes is None:
return None, None

new_attributes = []
old_attributes = []
# For Alias(Attribute(...), ...), we copy the parent's type from from_attributes by
# quoted name. That is only defined when that name appears once in FROM; if it appears
# more than once (e.g. table-function column overlap), name lookup is ambiguous — in
# that case we fail inference only when we actually need that lookup (see below).
name_counts: Counter[str] = Counter()
from_attr_map: Dict[str, Attribute] = {}
if from_attributes:
name_counts = Counter(a.name for a in from_attributes)
from_attr_map = {a.name: a for a in from_attributes if name_counts[a.name] == 1}

expected_attributes = []
resolved_in_order = []
for attr in attributes:
# Attributes are already resolved and don't require inferrence
if isinstance(attr, Attribute):
old_attributes.append(attr)
expected_attributes.append(attr)
resolved_in_order.append(attr)
continue

if isinstance(attr, Alias):
# If the first non-aliased child of an Alias node is Literal or Attribute
# the column can be inferred.
if isinstance(attr.child, (Literal, Attribute)) and attr.datatype:
attr = Attribute(attr.name, attr.datatype, attr.nullable)
if isinstance(attr.child, Attribute):
# Resolve type from from_attributes by matching child name.
if name_counts[attr.child.name] > 1:
return None, None
if attr.child.name in from_attr_map:
parent = from_attr_map[attr.child.name]
attr = Attribute(attr.name, parent.datatype, parent.nullable)
elif from_attributes is not None:
# FROM schema is known but doesn't contain this column name.
# This shouldn't happen in normal operation — bail out rather
# than proceed with a potentially wrong placeholder type.
return None, None
elif attr.datatype:
attr = Attribute(attr.name, attr.datatype, attr.nullable)
elif isinstance(attr.child, Literal):
if attr.datatype:
attr = Attribute(attr.name, attr.datatype, attr.nullable)
elif isinstance(attr.child, Cast):
# Use the Cast's declared target type for scalar types.
# Structured types (ArrayType, MapType, StructType) are excluded
# because Snowflake may promote their element types at execution
# time (e.g., ArrayType(IntegerType()) -> ArrayType(LongType())),
# so Cast.to may not match the server-returned type. For those,
# we fall through and let a describe query get the actual type.
if type(attr.child.to) is not DataType and not isinstance(
attr.child.to, (ArrayType, MapType, StructType)
):
attr = Attribute(attr.name, attr.child.to, attr.nullable)
elif isinstance(attr, Literal) and type(attr.datatype) != DataType:
# Names of literal values can be inferred
attr = Attribute(
to_sql(attr.value, attr.datatype), attr.datatype, attr.nullable
)

# If the attr has been coerced to attribute then it has been inferred
# If the attr has been coerced to attribute then it has been inferred.
if isinstance(attr, Attribute):
new_attributes.append(attr)
resolved_in_order.append(attr)
else:
return None, None
return old_attributes, new_attributes
# Every item in attributes was resolved; len(resolved_in_order) == len(attributes).
return expected_attributes, resolved_in_order


def try_infer_attributes_from_flattened_projection(
    final_projection: Optional[List[Expression]],
    parent_attributes: Optional[List[Attribute]],
) -> Optional[List[Attribute]]:
    """Re-derive attributes after a select() flattening, using parent types.

    With CTE optimization on, the ``SelectStatement.select()`` flattening path
    has to drop its cached ``_attributes`` since the projection changed. Instead
    of unconditionally resetting to ``None`` (which triggers a DESCRIBE query),
    attempt to resolve every column type from the parent's known attributes via
    ``_extract_inferable_attribute_names``. Succeeds only when each column in
    the new projection resolves to a concrete type (Alias over a resolved
    Attribute child, or a Literal); otherwise returns ``None`` so the existing
    describe-query behaviour remains the safe fallback.
    """
    if final_projection is None:
        return None
    if parent_attributes is None:
        return None

    resolved = _extract_inferable_attribute_names(
        final_projection, parent_attributes
    )[1]
    if resolved is None:
        return None

    # Accept only fully-typed Attributes; a bare DataType means "unresolved".
    for candidate in resolved:
        if not isinstance(candidate, Attribute):
            return None
        if type(candidate.datatype) is DataType:
            return None
    return resolved


def _extract_selectable_attributes(
Expand All @@ -142,28 +215,23 @@ def _extract_selectable_attributes(
else:
# Get the attributes from the child plan
from_attributes = _extract_selectable_attributes(current_plan.from_)
(
expected_attributes,
new_attributes,
# Extract expected attributes and knowable new attributes
# from current plan
) = _extract_inferable_attribute_names(current_plan.projection)
(expected_attributes, resolved,) = _extract_inferable_attribute_names(
current_plan.projection, from_attributes
)
# Check that the expected attributes match the attributes from the child plan
if (
from_attributes is not None
and expected_attributes is not None
and new_attributes is not None
and resolved is not None
):
missing_attrs = {attr.name for attr in expected_attributes} - {
attr.name for attr in from_attributes
}
if not missing_attrs and all(
isinstance(attr, (Attribute, Alias))
# If the attribute datatype is specifically DataType then it is not fully resolved
and type(attr.datatype) is not DataType
for attr in current_plan.projection or []
isinstance(attr, Attribute) and type(attr.datatype) is not DataType
for attr in resolved
):
attributes = current_plan.projection # type: ignore
attributes = resolved
elif (
isinstance(
current_plan,
Expand Down
17 changes: 12 additions & 5 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,11 +1460,6 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
if can_be_flattened:
new = copy(self)
final_projection = []
if (
self._session.reduce_describe_query_enabled
and self._session.cte_optimization_enabled
):
new._attributes = None # reset attributes since projection changed
assert new_column_states is not None
for col, state in new_column_states.items():
if state.change_state in (
Expand All @@ -1477,6 +1472,18 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
copy(self.column_states[col].expression)
) # add subquery's expression for this column name

if (
self._session.reduce_describe_query_enabled
and self._session.cte_optimization_enabled
):
from snowflake.snowpark._internal.analyzer.metadata_utils import (
try_infer_attributes_from_flattened_projection,
)

new._attributes = try_infer_attributes_from_flattened_projection(
final_projection,
self._attributes,
)
new.projection = final_projection
new.from_ = self.from_.to_subqueryable()
new.pre_actions = new.from_.pre_actions
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_binary(session, type, action):

def test_join_with_alias_dataframe(session):
expected_describe_count = (
3
1
if (session.reduce_describe_query_enabled and session.sql_simplifier_enabled)
else 4
)
Expand Down
48 changes: 23 additions & 25 deletions tests/integ/test_eager_schema_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,39 @@
)
@pytest.mark.parametrize("debug_mode", [True, False])
@pytest.mark.parametrize(
"transform",
"transform, infers_without_debug",
[
pytest.param(lambda x: copy(x), id="copy"),
pytest.param(lambda x: x.to_df(["C", "D"]), id="to_df"),
pytest.param(lambda x: x.distinct(), id="distinct"),
pytest.param(lambda x: x.drop_duplicates(), id="drop_duplicates"),
pytest.param(lambda x: x.limit(1), id="limit"),
pytest.param(lambda x: x.union(x), id="union"),
pytest.param(lambda x: x.union_all(x), id="union_all"),
pytest.param(lambda x: x.union_by_name(x), id="union_by_name"),
pytest.param(lambda x: x.union_all_by_name(x), id="union_all_by_name"),
pytest.param(lambda x: x.intersect(x), id="intersect"),
pytest.param(lambda x: x.natural_join(x), id="natural_join"),
pytest.param(lambda x: x.cross_join(x), id="cross_join"),
pytest.param(lambda x: x.sample(n=1), id="sample"),
pytest.param(lambda x: copy(x), False, id="copy"),
pytest.param(lambda x: x.to_df(["C", "D"]), True, id="to_df"),
pytest.param(lambda x: x.distinct(), False, id="distinct"),
pytest.param(lambda x: x.drop_duplicates(), False, id="drop_duplicates"),
pytest.param(lambda x: x.limit(1), False, id="limit"),
pytest.param(lambda x: x.union(x), False, id="union"),
pytest.param(lambda x: x.union_all(x), False, id="union_all"),
pytest.param(lambda x: x.union_by_name(x), False, id="union_by_name"),
pytest.param(lambda x: x.union_all_by_name(x), False, id="union_all_by_name"),
pytest.param(lambda x: x.intersect(x), False, id="intersect"),
pytest.param(lambda x: x.natural_join(x), False, id="natural_join"),
pytest.param(lambda x: x.cross_join(x), False, id="cross_join"),
pytest.param(lambda x: x.sample(n=1), False, id="sample"),
pytest.param(
lambda x: x.with_column_renamed(col("A"), "B"), id="with_column_renamed"
lambda x: x.with_column_renamed(col("A"), "B"),
False,
id="with_column_renamed",
),
# Unpivot already validates names
pytest.param(lambda x: x.unpivot("x", "y", ["A"]), id="unpivot"),
# The following functions do not error early because their schema_query do not contain
# information about the transformation being called.
pytest.param(lambda x: x.drop(col("A")), id="drop"),
pytest.param(lambda x: x.filter(col("A") == lit(1)), id="filter"),
pytest.param(lambda x: x.sort(col("A").desc()), id="sort"),
pytest.param(lambda x: x.unpivot("x", "y", ["A"]), False, id="unpivot"),
pytest.param(lambda x: x.drop(col("A")), False, id="drop"),
pytest.param(lambda x: x.filter(col("A") == lit(1)), False, id="filter"),
pytest.param(lambda x: x.sort(col("A").desc()), False, id="sort"),
],
)
def test_early_attributes(session, transform, debug_mode):
def test_early_attributes(session, transform, infers_without_debug, debug_mode):
with patch.object(context, "_debug_eager_schema_validation", debug_mode):
df = session.create_dataframe([(1, "A"), (2, "B"), (3, "C")], ["A", "B"])

transformed = transform(df)

# When debug mode is enabled the dataframe plan attributes are populated early
if debug_mode:
if debug_mode or infers_without_debug:
assert transformed._plan._metadata.attributes is not None
else:
assert transformed._plan._metadata.attributes is None
Expand Down
Loading
Loading