From dedbc1da4cbd4c83f8d21cd4d2ee881efb506f7a Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Mon, 13 Apr 2026 11:30:59 -0700 Subject: [PATCH 1/9] reuse attributes --- .../_internal/analyzer/metadata_utils.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index fd0950dbd8..adc59ef29b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -86,6 +86,7 @@ def infer_quoted_identifiers_from_expressions( def _extract_inferable_attribute_names( attributes: Optional[List[Expression]], + from_attributes: Optional[List[Attribute]] = None, ) -> tuple[Optional[List[Attribute]], Optional[List[Attribute]]]: """ Returns a list of attribute names that can be infered from a list of Expressions. @@ -94,6 +95,8 @@ def _extract_inferable_attribute_names( if attributes is None: return None, None + from_attr_map = {a.name: a for a in from_attributes} if from_attributes else None + new_attributes = [] old_attributes = [] for attr in attributes: @@ -105,7 +108,20 @@ def _extract_inferable_attribute_names( if isinstance(attr, Alias): # If the first non-aliased child of an Alias node is Literal or Attribute # the column can be inferred. - if isinstance(attr.child, (Literal, Attribute)) and attr.datatype: + if ( + isinstance(attr.child, (Literal, Attribute)) + and attr.datatype + and type(attr.datatype) is not DataType + ): + attr = Attribute(attr.name, attr.datatype, attr.nullable) + elif ( + isinstance(attr.child, Attribute) + and from_attr_map is not None + and attr.child.name in from_attr_map + ): + parent = from_attr_map[attr.child.name] + attr = Attribute(attr.name, parent.datatype, parent.nullable) + elif isinstance(attr.child, (Literal, Attribute)) and attr.datatype: attr = Attribute(attr.name, attr.datatype, attr.nullable) elif isinstance(attr, Literal) and type(attr.datatype) != DataType: # Names of literal values can be inferred @@ -142,12 +158,9 @@ def _extract_selectable_attributes( else: # Get the attributes from the child plan from_attributes = _extract_selectable_attributes(current_plan.from_) - ( - expected_attributes, - new_attributes, - # Extract expected attributes and knowable new attributes - # from current plan - ) = _extract_inferable_attribute_names(current_plan.projection) + (expected_attributes, new_attributes,) = _extract_inferable_attribute_names( + current_plan.projection, from_attributes + ) # Check that the expected attributes match the attributes from the child plan if ( from_attributes is not None @@ -157,13 +170,12 @@ def _extract_selectable_attributes( missing_attrs = {attr.name for attr in expected_attributes} - { attr.name for attr in from_attributes } + resolved = expected_attributes + new_attributes if not missing_attrs and all( - isinstance(attr, (Attribute, Alias)) - # If the attribute datatype is specifically DataType then it is not fully resolved - and type(attr.datatype) is not DataType - for attr in current_plan.projection or [] + isinstance(attr, Attribute) and type(attr.datatype) is not DataType + for attr in resolved ): - attributes = current_plan.projection # type: ignore + attributes = resolved elif ( isinstance( current_plan, From 5839ee6ea2c23187e46de210a7418ccfdcfff9b9 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Tue, 14 Apr 2026 21:00:51 -0700 Subject: [PATCH 2/9] update --- .../_internal/analyzer/metadata_utils.py | 34 +++++-------- tests/integ/test_cte.py | 2 +- tests/integ/test_eager_schema_validation.py | 48 +++++++++---------- 3 files changed, 37 insertions(+), 47 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index adc59ef29b..2609617160 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -89,32 +89,27 @@ def _extract_inferable_attribute_names( from_attributes: Optional[List[Attribute]] = None, ) -> tuple[Optional[List[Attribute]], Optional[List[Attribute]]]: """ - Returns a list of attribute names that can be infered from a list of Expressions. - Returns None if one or more attributes cannot be infered. + Returns a tuple of (expected_attributes, resolved_attributes) that can be + inferred from a list of Expressions. + - expected_attributes: Attributes that were already resolved (direct column refs) + - resolved_attributes: All attributes in projection order, with types resolved + Returns (None, None) if one or more attributes cannot be inferred. """ if attributes is None: return None, None from_attr_map = {a.name: a for a in from_attributes} if from_attributes else None - new_attributes = [] - old_attributes = [] + expected_attributes = [] + resolved_in_order = [] for attr in attributes: - # Attributes are already resolved and don't require inferrence if isinstance(attr, Attribute): - old_attributes.append(attr) + expected_attributes.append(attr) + resolved_in_order.append(attr) continue if isinstance(attr, Alias): - # If the first non-aliased child of an Alias node is Literal or Attribute - # the column can be inferred. if ( - isinstance(attr.child, (Literal, Attribute)) - and attr.datatype - and type(attr.datatype) is not DataType - ): - attr = Attribute(attr.name, attr.datatype, attr.nullable) - elif ( isinstance(attr.child, Attribute) and from_attr_map is not None and attr.child.name in from_attr_map @@ -124,17 +119,15 @@ def _extract_inferable_attribute_names( elif isinstance(attr.child, (Literal, Attribute)) and attr.datatype: attr = Attribute(attr.name, attr.datatype, attr.nullable) elif isinstance(attr, Literal) and type(attr.datatype) != DataType: - # Names of literal values can be inferred attr = Attribute( to_sql(attr.value, attr.datatype), attr.datatype, attr.nullable ) - # If the attr has been coerced to attribute then it has been inferred if isinstance(attr, Attribute): - new_attributes.append(attr) + resolved_in_order.append(attr) else: return None, None - return old_attributes, new_attributes + return expected_attributes, resolved_in_order def _extract_selectable_attributes( @@ -158,19 +151,18 @@ def _extract_selectable_attributes( else: # Get the attributes from the child plan from_attributes = _extract_selectable_attributes(current_plan.from_) - (expected_attributes, new_attributes,) = _extract_inferable_attribute_names( + (expected_attributes, resolved,) = _extract_inferable_attribute_names( current_plan.projection, from_attributes ) # Check that the expected attributes match the attributes from the child plan if ( from_attributes is not None and expected_attributes is not None - and new_attributes is not None + and resolved is not None ): missing_attrs = {attr.name for attr in expected_attributes} - { attr.name for attr in from_attributes } - resolved = expected_attributes + new_attributes if not missing_attrs and all( isinstance(attr, Attribute) and type(attr.datatype) is not DataType for attr in resolved diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 09639160dc..261706d3b1 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -259,7 +259,7 @@ def test_binary(session, type, action): def test_join_with_alias_dataframe(session): expected_describe_count = ( - 3 + 1 if (session.reduce_describe_query_enabled and session.sql_simplifier_enabled) else 4 ) diff --git a/tests/integ/test_eager_schema_validation.py b/tests/integ/test_eager_schema_validation.py index 6f7a18541e..835fc89634 100644 --- a/tests/integ/test_eager_schema_validation.py +++ b/tests/integ/test_eager_schema_validation.py @@ -18,41 +18,39 @@ ) @pytest.mark.parametrize("debug_mode", [True, False]) @pytest.mark.parametrize( - "transform", + "transform, infers_without_debug", [ - pytest.param(lambda x: copy(x), id="copy"), - pytest.param(lambda x: x.to_df(["C", "D"]), id="to_df"), - pytest.param(lambda x: x.distinct(), id="distinct"), - pytest.param(lambda x: x.drop_duplicates(), id="drop_duplicates"), - pytest.param(lambda x: x.limit(1), id="limit"), - pytest.param(lambda x: x.union(x), id="union"), - pytest.param(lambda x: x.union_all(x), id="union_all"), - pytest.param(lambda x: x.union_by_name(x), id="union_by_name"), - pytest.param(lambda x: x.union_all_by_name(x), id="union_all_by_name"), - pytest.param(lambda x: x.intersect(x), id="intersect"), - pytest.param(lambda x: x.natural_join(x), id="natural_join"), - pytest.param(lambda x: x.cross_join(x), id="cross_join"), - pytest.param(lambda x: x.sample(n=1), id="sample"), + pytest.param(lambda x: copy(x), False, id="copy"), + pytest.param(lambda x: x.to_df(["C", "D"]), True, id="to_df"), + pytest.param(lambda x: x.distinct(), False, id="distinct"), + pytest.param(lambda x: x.drop_duplicates(), False, id="drop_duplicates"), + pytest.param(lambda x: x.limit(1), False, id="limit"), + pytest.param(lambda x: x.union(x), False, id="union"), + pytest.param(lambda x: x.union_all(x), False, id="union_all"), + pytest.param(lambda x: x.union_by_name(x), False, id="union_by_name"), + pytest.param(lambda x: x.union_all_by_name(x), False, id="union_all_by_name"), + pytest.param(lambda x: x.intersect(x), False, id="intersect"), + pytest.param(lambda x: x.natural_join(x), False, id="natural_join"), + pytest.param(lambda x: x.cross_join(x), False, id="cross_join"), + pytest.param(lambda x: x.sample(n=1), False, id="sample"), pytest.param( - lambda x: x.with_column_renamed(col("A"), "B"), id="with_column_renamed" + lambda x: x.with_column_renamed(col("A"), "B"), + False, + id="with_column_renamed", ), - # Unpivot already validates names - pytest.param(lambda x: x.unpivot("x", "y", ["A"]), id="unpivot"), - # The following functions do not error early because their schema_query do not contain - # information about the transformation being called. - pytest.param(lambda x: x.drop(col("A")), id="drop"), - pytest.param(lambda x: x.filter(col("A") == lit(1)), id="filter"), - pytest.param(lambda x: x.sort(col("A").desc()), id="sort"), + pytest.param(lambda x: x.unpivot("x", "y", ["A"]), False, id="unpivot"), + pytest.param(lambda x: x.drop(col("A")), False, id="drop"), + pytest.param(lambda x: x.filter(col("A") == lit(1)), False, id="filter"), + pytest.param(lambda x: x.sort(col("A").desc()), False, id="sort"), ], ) -def test_early_attributes(session, transform, debug_mode): +def test_early_attributes(session, transform, infers_without_debug, debug_mode): with patch.object(context, "_debug_eager_schema_validation", debug_mode): df = session.create_dataframe([(1, "A"), (2, "B"), (3, "C")], ["A", "B"]) transformed = transform(df) - # When debug mode is enabled the dataframe plan attributes are populated early - if debug_mode: + if debug_mode or infers_without_debug: assert transformed._plan._metadata.attributes is not None else: assert transformed._plan._metadata.attributes is None From d5effafee2cf428744fdced68d1eb7463e747763 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 16 Apr 2026 15:13:05 -0700 Subject: [PATCH 3/9] update --- .../_internal/analyzer/metadata_utils.py | 35 ++++++++++++++++++- .../_internal/analyzer/select_statement.py | 17 ++++++--- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index 2609617160..8794f8b0fd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -13,7 +13,7 @@ Star, ) from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql -from snowflake.snowpark._internal.analyzer.unary_expression import Alias +from snowflake.snowpark._internal.analyzer.unary_expression import Alias, Cast from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( Limit, LogicalPlan, @@ -116,6 +116,8 @@ def _extract_inferable_attribute_names( ): parent = from_attr_map[attr.child.name] attr = Attribute(attr.name, parent.datatype, parent.nullable) + elif isinstance(attr.child, Cast) and type(attr.child.to) is not DataType: + attr = Attribute(attr.name, attr.child.to, attr.nullable) elif isinstance(attr.child, (Literal, Attribute)) and attr.datatype: attr = Attribute(attr.name, attr.datatype, attr.nullable) elif isinstance(attr, Literal) and type(attr.datatype) != DataType: @@ -130,6 +132,37 @@ def _extract_inferable_attribute_names( return expected_attributes, resolved_in_order +def try_infer_attributes_from_flattened_projection( + final_projection: Optional[List[Expression]], + parent_attributes: Optional[List[Attribute]], +) -> Optional[List[Attribute]]: + """Re-derive attributes after a select() flattening, using parent types. + + When CTE optimization is enabled the ``SelectStatement.select()`` flattening + path must invalidate cached ``_attributes`` because the projection changed. + Rather than unconditionally setting ``_attributes = None`` (which forces a + DESCRIBE query), this helper tries to resolve column types from the parent's + already-known attributes via ``_extract_inferable_attribute_names``. + Only succeeds when every column in the new projection can be fully resolved + to a concrete type (Alias with a resolved Attribute child, or Literal). + Returns ``None`` whenever any column cannot be resolved -- preserving the + existing describe-query behaviour as a safe default. + """ + if final_projection is None or parent_attributes is None: + return None + + _, resolved = _extract_inferable_attribute_names( + final_projection, parent_attributes + ) + + if resolved is not None and all( + isinstance(a, Attribute) and type(a.datatype) is not DataType for a in resolved + ): + return resolved + + return None + + def _extract_selectable_attributes( current_plan: LogicalPlan, ) -> Optional[List[Attribute]]: diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index f2f8f4a4b2..71368c59d7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1460,11 +1460,6 @@ def select(self, cols: List[Expression]) -> "SelectStatement": if can_be_flattened: new = copy(self) final_projection = [] - if ( - self._session.reduce_describe_query_enabled - and self._session.cte_optimization_enabled - ): - new._attributes = None # reset attributes since projection changed assert new_column_states is not None for col, state in new_column_states.items(): if state.change_state in ( @@ -1477,6 +1472,18 @@ def select(self, cols: List[Expression]) -> "SelectStatement": copy(self.column_states[col].expression) ) # add subquery's expression for this column name + if ( + self._session.reduce_describe_query_enabled + and self._session.cte_optimization_enabled + ): + from snowflake.snowpark._internal.analyzer.metadata_utils import ( + try_infer_attributes_from_flattened_projection, + ) + + new._attributes = try_infer_attributes_from_flattened_projection( + final_projection, + self._attributes, + ) new.projection = final_projection new.from_ = self.from_.to_subqueryable() new.pre_actions = new.from_.pre_actions From d7b0739a32f724ced89f3211492eb58172a62867 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 16 Apr 2026 17:26:49 -0700 Subject: [PATCH 4/9] fix structure type --- .../snowpark/_internal/analyzer/metadata_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index 8794f8b0fd..9eb10dec75 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -4,7 +4,7 @@ from enum import Enum from dataclasses import dataclass from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union -from snowflake.snowpark.types import DataType +from snowflake.snowpark.types import ArrayType, DataType, MapType, StructType from snowflake.snowpark._internal.analyzer.expression import ( Attribute, @@ -116,7 +116,11 @@ def _extract_inferable_attribute_names( ): parent = from_attr_map[attr.child.name] attr = Attribute(attr.name, parent.datatype, parent.nullable) - elif isinstance(attr.child, Cast) and type(attr.child.to) is not DataType: + elif ( + isinstance(attr.child, Cast) + and type(attr.child.to) is not DataType + and not isinstance(attr.child.to, (ArrayType, MapType, StructType)) + ): attr = Attribute(attr.name, attr.child.to, attr.nullable) elif isinstance(attr.child, (Literal, Attribute)) and attr.datatype: attr = Attribute(attr.name, attr.datatype, attr.nullable) From b2c25e19e424a4c51bc2f064aee1ea3fc6971db1 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Thu, 16 Apr 2026 23:12:23 -0700 Subject: [PATCH 5/9] update --- .../_internal/analyzer/metadata_utils.py | 16 +- tests/integ/test_reduce_describe_query.py | 59 +++++- tests/unit/test_metadata_utils.py | 175 ++++++++++++++++++ 3 files changed, 243 insertions(+), 7 deletions(-) create mode 100644 tests/unit/test_metadata_utils.py diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index 9eb10dec75..db45f5794b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -98,7 +98,11 @@ def _extract_inferable_attribute_names( if attributes is None: return None, None - from_attr_map = {a.name: a for a in from_attributes} if from_attributes else None + from_attr_map = {a.name: a for a in from_attributes} if from_attributes else {} + assert not from_attributes or len(from_attr_map) == len(from_attributes), ( + f"Unexpected duplicate column names in from_attributes: " + f"{[a.name for a in from_attributes]}" + ) expected_attributes = [] resolved_in_order = [] @@ -109,11 +113,11 @@ def _extract_inferable_attribute_names( continue if isinstance(attr, Alias): - if ( - isinstance(attr.child, Attribute) - and from_attr_map is not None - and attr.child.name in from_attr_map - ): + # In the SQL simplifier model, a SelectStatement's projection can only + # reference columns from its FROM clause. So attr.child (an Attribute) + # is always a reference to a column in from_attributes, and the name-based + # lookup is safe because from_attributes names are unique (asserted above). + if isinstance(attr.child, Attribute) and attr.child.name in from_attr_map: parent = from_attr_map[attr.child.name] attr = Attribute(attr.name, parent.datatype, parent.nullable) elif ( diff --git a/tests/integ/test_reduce_describe_query.py b/tests/integ/test_reduce_describe_query.py index 5fe9e033b2..bb16a3ba82 100644 --- a/tests/integ/test_reduce_describe_query.py +++ b/tests/integ/test_reduce_describe_query.py @@ -34,7 +34,7 @@ _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, Session, ) -from snowflake.snowpark.types import LongType, StructField, StructType +from snowflake.snowpark.types import LongType, StringType, StructField, StructType from tests.integ.utils.sql_counter import SqlCounter from tests.utils import IS_IN_STORED_PROC, TestData @@ -493,3 +493,60 @@ def test_update_schema_query_when_attributes_available(session): ) if should_simplify: assert df._plan.schema_query == simplified_schema_query2 + + +def test_infer_cast_type_reduces_describe(session): + """Verify that Alias(Cast) inference avoids extra describe queries, + and that results/schema are correct across various projection patterns.""" + df = session.create_dataframe( + [[1, 2, "a"], [3, 4, "b"]], + schema=StructType( + [ + StructField("A", LongType()), + StructField("B", LongType()), + StructField("NAME", StringType()), + ] + ), + ) + + # Resolve schema once (may need describe for typed StructType + Cast projections) + schema = df.schema + assert schema[0].datatype == LongType() + assert schema[2].datatype == StringType() + + # 1. Self-join via alias -- both sides reference the same createDataFrame. + # Cast inference should avoid extra describes for the second reference. + df_l = df.alias("L") + df_r = df.alias("R") + joined = df_l.join(df_r, col("L", "A") == col("R", "A")).select( + col("L", "A"), col("R", "B"), col("L", "NAME") + ) + result = joined.collect() + assert len(result) > 0 + + # 2. Select with rename -- aliasing should not break inference + renamed = df.select(col("A").alias("X"), col("B")) + schema = renamed.schema + assert schema[0].name == "X" + assert schema[0].datatype == LongType() + assert schema[1].datatype == LongType() + + # 3. Filter preserves schema without extra describe + filtered = df.filter(col("A") > 1) + if session.reduce_describe_query_enabled: + with SqlCounter(query_count=0, describe_count=0): + schema = filtered.schema + assert schema == df.schema + else: + schema = filtered.schema + assert schema == df.schema + + # 4. Verify correctness of cast inference -- types must match server types + cast_df = df.select( + col("A").cast(StringType()).alias("A_STR"), + col("NAME"), + ) + result = cast_df.collect() + assert result[0]["A_STR"] == "1" + assert cast_df.schema[0].datatype == StringType() + assert cast_df.schema[1].datatype == StringType() diff --git a/tests/unit/test_metadata_utils.py b/tests/unit/test_metadata_utils.py new file mode 100644 index 0000000000..a20a89ea07 --- /dev/null +++ b/tests/unit/test_metadata_utils.py @@ -0,0 +1,175 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +import pytest + +from snowflake.snowpark._internal.analyzer.expression import ( + Attribute, + Expression, + Literal, +) +from snowflake.snowpark._internal.analyzer.metadata_utils import ( + _extract_inferable_attribute_names, +) +from snowflake.snowpark._internal.analyzer.unary_expression import Alias, Cast +from snowflake.snowpark.types import ( + ArrayType, + DataType, + DoubleType, + IntegerType, + MapType, + StringType, + StructType, +) + + +def test_none_input(): + assert _extract_inferable_attribute_names(None) == (None, None) + + +def test_plain_attribute_passthrough(): + attrs = [Attribute('"A"', IntegerType()), Attribute('"B"', StringType())] + expected, resolved = _extract_inferable_attribute_names(attrs) + assert len(expected) == 2 + assert len(resolved) == 2 + assert resolved[0].name == '"A"' + assert resolved[0].datatype == IntegerType() + assert resolved[1].name == '"B"' + assert resolved[1].datatype == StringType() + + +def test_alias_attribute_resolved_from_parent(): + """Alias(Attribute) resolves type from from_attributes by name.""" + child = Attribute('"A"', DataType()) + projection = [Alias(child, '"X"')] + from_attributes = [Attribute('"A"', IntegerType(), nullable=False)] + + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"X"' + assert resolved[0].datatype == IntegerType() + + +def test_alias_attribute_no_match_in_parent(): + """Alias(Attribute) with no matching name in from_attributes falls through + to the Alias(Literal/Attribute)+datatype branch. With DataType() (dummy), + it resolves but with a non-concrete type that downstream code rejects.""" + child = Attribute('"MISSING"', DataType()) + projection = [Alias(child, '"X"')] + from_attributes = [Attribute('"A"', IntegerType())] + + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"X"' + assert type(resolved[0].datatype) is DataType + + +def test_alias_cast_scalar_type(): + """Alias(Cast(to=scalar_type)) resolves to the Cast target type.""" + inner = Attribute('"A"', DataType()) + cast = Cast(inner, IntegerType()) + projection = [Alias(cast, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"X"' + assert resolved[0].datatype == IntegerType() + + +@pytest.mark.parametrize( + "structured_type", + [ + ArrayType(IntegerType()), + MapType(StringType(), IntegerType()), + StructType(), + ], + ids=["ArrayType", "MapType", "StructType"], +) +def test_alias_cast_structured_type_skipped(structured_type): + """Alias(Cast(to=structured_type)) is not inferred -- falls back to describe.""" + inner = Attribute('"A"', DataType()) + cast = Cast(inner, structured_type) + projection = [Alias(cast, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_alias_cast_dummy_datatype_skipped(): + """Alias(Cast(to=DataType())) is not inferred -- base DataType is not concrete.""" + inner = Attribute('"A"', DataType()) + cast = Cast(inner, DataType()) + projection = [Alias(cast, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_alias_literal_with_known_type(): + """Alias(Literal) with concrete type resolves (existing behavior).""" + lit = Literal(42, IntegerType()) + projection = [Alias(lit, '"X"')] + projection[0].datatype = IntegerType() + + expected, resolved = _extract_inferable_attribute_names(projection) + assert resolved is not None + assert resolved[0].name == '"X"' + assert resolved[0].datatype == IntegerType() + + +def test_unresolvable_expression(): + """Alias with unknown child type -> (None, None).""" + + class UnknownExpr(Expression): + pass + + projection = [Alias(UnknownExpr(), '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_mixed_resolvable_and_unresolvable(): + """If any expression is unresolvable, entire result is (None, None).""" + good = Attribute('"A"', IntegerType()) + + class UnknownExpr(Expression): + pass + + bad = Alias(UnknownExpr(), '"B"') + projection = [good, bad] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_multiple_cast_types(): + """Multiple Alias(Cast) with different scalar types all resolve.""" + projection = [ + Alias(Cast(Attribute('"A"', DataType()), IntegerType()), '"X"'), + Alias(Cast(Attribute('"B"', DataType()), StringType()), '"Y"'), + Alias(Cast(Attribute('"C"', DataType()), DoubleType()), '"Z"'), + ] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert resolved is not None + assert len(resolved) == 3 + assert resolved[0].datatype == IntegerType() + assert resolved[1].datatype == StringType() + assert resolved[2].datatype == DoubleType() + + +def test_duplicate_from_attributes_raises_assertion(): + """from_attributes with duplicate names should trigger AssertionError.""" + projection = [Attribute('"A"', IntegerType())] + from_attributes = [ + Attribute('"X"', IntegerType()), + Attribute('"X"', StringType()), + ] + + with pytest.raises(AssertionError, match="Unexpected duplicate column names"): + _extract_inferable_attribute_names(projection, from_attributes) From 07c3183c99771b61e0b2d6ce12ca08d4f543055a Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 17 Apr 2026 11:35:50 -0700 Subject: [PATCH 6/9] bug fix --- .../_internal/analyzer/metadata_utils.py | 28 ++++++++++++------- tests/unit/test_metadata_utils.py | 26 ++++++++++++++--- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index db45f5794b..1c67635eb5 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -1,6 +1,7 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import Counter from enum import Enum from dataclasses import dataclass from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union @@ -98,11 +99,15 @@ def _extract_inferable_attribute_names( if attributes is None: return None, None - from_attr_map = {a.name: a for a in from_attributes} if from_attributes else {} - assert not from_attributes or len(from_attr_map) == len(from_attributes), ( - f"Unexpected duplicate column names in from_attributes: " - f"{[a.name for a in from_attributes]}" - ) + # For Alias(Attribute(...), ...), we copy the parent's type from from_attributes by + # quoted name. That is only defined when that name appears once in FROM; if it appears + # more than once (e.g. table-function column overlap), name lookup is ambiguous — in + # that case we fail inference only when we actually need that lookup (see below). + name_counts: Counter[str] = Counter() + from_attr_map: Dict[str, Attribute] = {} + if from_attributes: + name_counts = Counter(a.name for a in from_attributes) + from_attr_map = {a.name: a for a in from_attributes if name_counts[a.name] == 1} expected_attributes = [] resolved_in_order = [] @@ -115,11 +120,14 @@ def _extract_inferable_attribute_names( if isinstance(attr, Alias): # In the SQL simplifier model, a SelectStatement's projection can only # reference columns from its FROM clause. So attr.child (an Attribute) - # is always a reference to a column in from_attributes, and the name-based - # lookup is safe because from_attributes names are unique (asserted above). - if isinstance(attr.child, Attribute) and attr.child.name in from_attr_map: - parent = from_attr_map[attr.child.name] - attr = Attribute(attr.name, parent.datatype, parent.nullable) + # is usually a reference to a column in from_attributes; parent types are + # merged by name when that name is unique in FROM. + if isinstance(attr.child, Attribute): + if name_counts[attr.child.name] > 1: + return None, None + if attr.child.name in from_attr_map: + parent = from_attr_map[attr.child.name] + attr = Attribute(attr.name, parent.datatype, parent.nullable) elif ( isinstance(attr.child, Cast) and type(attr.child.to) is not DataType diff --git a/tests/unit/test_metadata_utils.py b/tests/unit/test_metadata_utils.py index a20a89ea07..a322504ecd 100644 --- a/tests/unit/test_metadata_utils.py +++ b/tests/unit/test_metadata_utils.py @@ -163,13 +163,31 @@ def test_multiple_cast_types(): assert resolved[2].datatype == DoubleType() -def test_duplicate_from_attributes_raises_assertion(): - """from_attributes with duplicate names should trigger AssertionError.""" +def test_duplicate_from_attributes_plain_projection_still_resolves(): + """Duplicate FROM names do not block inference when projection only uses plain + Attributes (no ambiguous name-based parent lookup for Alias).""" projection = [Attribute('"A"', IntegerType())] from_attributes = [ Attribute('"X"', IntegerType()), Attribute('"X"', StringType()), + Attribute('"A"', StringType()), ] - with pytest.raises(AssertionError, match="Unexpected duplicate column names"): - _extract_inferable_attribute_names(projection, from_attributes) + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"A"' + + +def test_duplicate_from_attributes_alias_child_ambiguous_returns_none(): + """Alias(Attribute) cannot inherit type by name when that name is duplicated in FROM.""" + projection = [ + Alias(Attribute('"X"', DataType()), '"Y"'), + ] + from_attributes = [ + Attribute('"X"', IntegerType()), + Attribute('"X"', StringType()), + ] + + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert (expected, resolved) == (None, None) From 8d709e8b221ad67c19c9fd6d189aa18d5059ae39 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 17 Apr 2026 12:53:54 -0700 Subject: [PATCH 7/9] update --- .../_internal/analyzer/metadata_utils.py | 28 +++++++++++-------- tests/unit/test_metadata_utils.py | 17 +++++++++-- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index 1c67635eb5..3740e3e3fd 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -118,24 +118,28 @@ def _extract_inferable_attribute_names( continue if isinstance(attr, Alias): - # In the SQL simplifier model, a SelectStatement's projection can only - # reference columns from its FROM clause. So attr.child (an Attribute) - # is usually a reference to a column in from_attributes; parent types are - # merged by name when that name is unique in FROM. if isinstance(attr.child, Attribute): + # Resolve type from from_attributes by matching child name. if name_counts[attr.child.name] > 1: return None, None if attr.child.name in from_attr_map: parent = from_attr_map[attr.child.name] attr = Attribute(attr.name, parent.datatype, parent.nullable) - elif ( - isinstance(attr.child, Cast) - and type(attr.child.to) is not DataType - and not isinstance(attr.child.to, (ArrayType, MapType, StructType)) - ): - attr = Attribute(attr.name, attr.child.to, attr.nullable) - elif isinstance(attr.child, (Literal, Attribute)) and attr.datatype: - attr = Attribute(attr.name, attr.datatype, attr.nullable) + elif from_attributes is not None: + # FROM schema is known but doesn't contain this column name. + # This shouldn't happen in normal operation — bail out rather + # than proceed with a potentially wrong placeholder type. + return None, None + elif attr.datatype: + attr = Attribute(attr.name, attr.datatype, attr.nullable) + elif isinstance(attr.child, Literal): + if attr.datatype: + attr = Attribute(attr.name, attr.datatype, attr.nullable) + elif isinstance(attr.child, Cast): + if type(attr.child.to) is not DataType and not isinstance( + attr.child.to, (ArrayType, MapType, StructType) + ): + attr = Attribute(attr.name, attr.child.to, attr.nullable) elif isinstance(attr, Literal) and type(attr.datatype) != DataType: attr = Attribute( to_sql(attr.value, attr.datatype), attr.datatype, attr.nullable diff --git a/tests/unit/test_metadata_utils.py b/tests/unit/test_metadata_utils.py index a322504ecd..981ea74c09 100644 --- a/tests/unit/test_metadata_utils.py +++ b/tests/unit/test_metadata_utils.py @@ -53,14 +53,25 @@ def test_alias_attribute_resolved_from_parent(): def test_alias_attribute_no_match_in_parent(): - """Alias(Attribute) with no matching name in from_attributes falls through - to the Alias(Literal/Attribute)+datatype branch. With DataType() (dummy), - it resolves but with a non-concrete type that downstream code rejects.""" + """Alias(Attribute) references a name not present in from_attributes — inference fails.""" child = Attribute('"MISSING"', DataType()) projection = [Alias(child, '"X"')] from_attributes = [Attribute('"A"', IntegerType())] expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert (expected, resolved) == (None, None) + + +def test_alias_attribute_no_from_attributes_uses_own_datatype(): + """Without from_attributes, Alias(Attribute) resolves via attr.datatype (main compat). + + The resolved Attribute carries a placeholder DataType(); downstream callers + reject it via ``type(attr.datatype) is not DataType``, so no incorrect + metadata is cached — but the function itself does not return (None, None).""" + child = Attribute('"MISSING"', DataType()) + projection = [Alias(child, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) assert resolved is not None assert len(resolved) == 1 assert resolved[0].name == '"X"' From 0295332eb516b91b1f2861ca00f20aff3f71a7be Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 17 Apr 2026 15:14:51 -0700 Subject: [PATCH 8/9] update tests --- .../_internal/analyzer/metadata_utils.py | 5 + tests/integ/test_reduce_describe_query.py | 202 +++++++++++++----- 2 files changed, 154 insertions(+), 53 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index 3740e3e3fd..3f4673e866 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -118,6 +118,8 @@ def _extract_inferable_attribute_names( continue if isinstance(attr, Alias): + # If the first non-aliased child of an Alias node is Literal or Attribute + # the column can be inferred. if isinstance(attr.child, Attribute): # Resolve type from from_attributes by matching child name. if name_counts[attr.child.name] > 1: @@ -141,14 +143,17 @@ def _extract_inferable_attribute_names( ): attr = Attribute(attr.name, attr.child.to, attr.nullable) elif isinstance(attr, Literal) and type(attr.datatype) != DataType: + # Names of literal values can be inferred attr = Attribute( to_sql(attr.value, attr.datatype), attr.datatype, attr.nullable ) + # If the attr has been coerced to attribute then it has been inferred. if isinstance(attr, Attribute): resolved_in_order.append(attr) else: return None, None + # Every item in attributes was resolved; len(resolved_in_order) == len(attributes). return expected_attributes, resolved_in_order diff --git a/tests/integ/test_reduce_describe_query.py b/tests/integ/test_reduce_describe_query.py index bb16a3ba82..c050214521 100644 --- a/tests/integ/test_reduce_describe_query.py +++ b/tests/integ/test_reduce_describe_query.py @@ -495,58 +495,154 @@ def test_update_schema_query_when_attributes_available(session): assert df._plan.schema_query == simplified_schema_query2 -def test_infer_cast_type_reduces_describe(session): - """Verify that Alias(Cast) inference avoids extra describe queries, - and that results/schema are correct across various projection patterns.""" - df = session.create_dataframe( - [[1, 2, "a"], [3, 4, "b"]], - schema=StructType( - [ - StructField("A", LongType()), - StructField("B", LongType()), - StructField("NAME", StringType()), - ] - ), - ) +def test_infer_cast_type_reduces_describe_for_self_join(session): + """Verify that Alias(Cast) inference reduces describe queries for self-joins + with typed schemas, and that results are correct. + + With reduce_describe_query_enabled=True, the Cast inference avoids redundant + describes for repeated references to the same table (self-join). Without it, + each reference triggers its own describe. + """ + # Without optimization: 5 describes for create + self-join + select + collect + # With optimization: 1 describe (only the initial schema resolution) + expected_describe_count = 1 if session.reduce_describe_query_enabled else 5 + with SqlCounter( + query_count=1, describe_count=expected_describe_count, join_count=1 + ): + df = session.create_dataframe( + [[1, 2], [3, 4]], + schema=StructType( + [StructField("A", LongType()), StructField("B", LongType())] + ), + ) + joined = ( + df.alias("L") + .join(df.alias("R"), col("L", "A") == col("R", "A")) + .select(col("L", "A"), col("R", "B")) + ) + result = joined.collect() - # Resolve schema once (may need describe for typed StructType + Cast projections) - schema = df.schema - assert schema[0].datatype == LongType() - assert schema[2].datatype == StringType() - - # 1. Self-join via alias -- both sides reference the same createDataFrame. - # Cast inference should avoid extra describes for the second reference. - df_l = df.alias("L") - df_r = df.alias("R") - joined = df_l.join(df_r, col("L", "A") == col("R", "A")).select( - col("L", "A"), col("R", "B"), col("L", "NAME") - ) - result = joined.collect() assert len(result) > 0 - - # 2. Select with rename -- aliasing should not break inference - renamed = df.select(col("A").alias("X"), col("B")) - schema = renamed.schema - assert schema[0].name == "X" - assert schema[0].datatype == LongType() - assert schema[1].datatype == LongType() - - # 3. Filter preserves schema without extra describe - filtered = df.filter(col("A") > 1) - if session.reduce_describe_query_enabled: - with SqlCounter(query_count=0, describe_count=0): - schema = filtered.schema - assert schema == df.schema - else: - schema = filtered.schema - assert schema == df.schema - - # 4. Verify correctness of cast inference -- types must match server types - cast_df = df.select( - col("A").cast(StringType()).alias("A_STR"), - col("NAME"), - ) - result = cast_df.collect() - assert result[0]["A_STR"] == "1" - assert cast_df.schema[0].datatype == StringType() - assert cast_df.schema[1].datatype == StringType() + f = joined.schema.fields + assert len(f) == 2 + assert f[0] == StructField("AL", LongType(), nullable=True) + assert f[1] == StructField("BL", LongType(), nullable=True) + + +def test_infer_cast_type_correctness_across_projections(session): + """Verify that Cast/Alias inference produces correct types across + various projection patterns with CTE enabled, exercising the + try_infer_attributes_from_flattened_projection code path. + Each sub-case checks both result correctness and schema types.""" + with patch.object(session, "_cte_optimization_enabled", True): + df = session.create_dataframe( + [[1, "a"], [2, "b"]], + schema=StructType( + [StructField("A", LongType()), StructField("NAME", StringType())] + ), + ) + df2 = session.create_dataframe( + [[3, "c"]], + schema=StructType( + [StructField("A", LongType()), StructField("NAME", StringType())] + ), + ) + + # 1. cast + union: cast changes type to StringType, union preserves it. + cu = df.select(col("A").cast(StringType()).alias("A"), col("NAME")).union_all( + df2.select(col("A").cast(StringType()).alias("A"), col("NAME")) + ) + assert len(cu.collect()) == 3 + assert cu.schema == StructType( + [ + StructField("A", StringType(), nullable=True), + StructField("NAME", StringType(), nullable=True), + ] + ) + + # 2. cast + self-join: Cast inference preserves StringType through join. + casted = df.select(col("A").cast(StringType()).alias("A_STR"), col("NAME")) + cj = casted.alias("L").join( + casted.alias("R"), col("L", "A_STR") == col("R", "A_STR") + ) + assert len(cj.collect()) > 0 + f = cj.schema.fields + assert f[0] == StructField("A_STRL", StringType(), nullable=True) + assert f[1] == StructField("NAMEL", StringType(), nullable=True) + assert f[2] == StructField("A_STRR", StringType(), nullable=True) + assert f[3] == StructField("NAMER", StringType(), nullable=True) + + # 3. alias col A (LongType) to "NAME" then join with original NAME (StringType). + # The rename must carry A's type (LongType), not original NAME's type. + renamed = df.select(col("A").alias("NAME")) + rj = renamed.alias("L").join( + df.alias("R"), col("L", "NAME") == col("R", "NAME") + ) + assert len(rj.collect()) > 0 + f = rj.schema.fields + assert len(f) == 3 + assert f[0] == StructField("NAMEL", LongType(), nullable=True) + assert f[1] == StructField("A", LongType(), nullable=True) + assert f[2] == StructField("NAMER", StringType(), nullable=True) + + +def test_infer_cast_type_reduces_describe_sas_style_self_join(session): + """Mimic SAS (Snowpark Connect) internal mechanism for self-joins: + + SAS translates SQL like "FROM users e JOIN users m ON ..." by: + 1. session.create_dataframe(..., schema=StructType) -> temp table with Cast projections + 2. session.table(temp_view_name) -> two independent SelectableEntity references + 3. df.select(col(orig).cast(type).alias(plan_id_name)) -> renames with plan-specific IDs + 4. left.join(right, condition).select(...) -> join + projection + + This test reproduces that pattern in pure Snowpark to verify describe reduction. + """ + with patch.object(session, "_cte_optimization_enabled", True): + # Step 1: create a temp table (simulates createDataFrame + createOrReplaceTempView) + base = session.create_dataframe( + [[1, "alice", None], [2, "bob", 1], [3, "charlie", 1]], + schema=StructType( + [ + StructField("USER_ID", LongType()), + StructField("NAME", StringType()), + StructField("MANAGER_ID", LongType()), + ] + ), + ) + temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) + base.write.save_as_table(temp_table_name, mode="overwrite", table_type="temp") + + # Steps 2-4 wrapped in SqlCounter to measure describes during plan building. + # SAS-style: table() -> select(cast.alias) -> join -> select -> collect + # With optimization: 2 describes (vs 5 without) -- Cast inference avoids 3 + expected_describe = 2 if session.reduce_describe_query_enabled else 5 + with SqlCounter(query_count=1, describe_count=expected_describe, join_count=1): + left_ref = session.table(temp_table_name) + right_ref = session.table(temp_table_name) + + left = left_ref.select( + col("USER_ID").cast(LongType()).alias("USER_ID_L"), + col("NAME").cast(StringType()).alias("NAME_L"), + col("MANAGER_ID").cast(LongType()).alias("MANAGER_ID_L"), + ) + right = right_ref.select( + col("USER_ID").cast(LongType()).alias("USER_ID_R"), + col("NAME").cast(StringType()).alias("NAME_R"), + col("MANAGER_ID").cast(LongType()).alias("MANAGER_ID_R"), + ) + + joined = left.join(right, col("MANAGER_ID_L") == col("USER_ID_R")).select( + col("USER_ID_L"), + col("NAME_L"), + col("USER_ID_R"), + col("NAME_R"), + ) + result = joined.collect() + + assert len(result) == 2 + f = joined.schema.fields + assert len(f) == 4 + assert f[0] == StructField("USER_ID_L", LongType(), nullable=True) + assert f[1] == StructField("NAME_L", StringType(), nullable=True) + assert f[2] == StructField("USER_ID_R", LongType(), nullable=True) + assert f[3] == StructField("NAME_R", StringType(), nullable=True) From ba5373ca2bdb3132729b2333b63f14985dd7df00 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 17 Apr 2026 16:17:33 -0700 Subject: [PATCH 9/9] add comment --- src/snowflake/snowpark/_internal/analyzer/metadata_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index 3f4673e866..bd718529a5 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -138,6 +138,12 @@ def _extract_inferable_attribute_names( if attr.datatype: attr = Attribute(attr.name, attr.datatype, attr.nullable) elif isinstance(attr.child, Cast): + # Use the Cast's declared target type for scalar types. + # Structured types (ArrayType, MapType, StructType) are excluded + # because Snowflake may promote their element types at execution + # time (e.g., ArrayType(IntegerType()) -> ArrayType(LongType())), + # so Cast.to may not match the server-returned type. For those, + # we fall through and let a describe query get the actual type. if type(attr.child.to) is not DataType and not isinstance( attr.child.to, (ArrayType, MapType, StructType) ):