diff --git a/CHANGELOG.md b/CHANGELOG.md index f21646bae7..fe8e9090bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ - Fixed a bug where using parameter bindings for `CALL` queries issued through `session.sql` would raise an error. - Fixed a bug where `StringType` columns from Iceberg tables were not recognized as max-size strings. +#### Improvements + +- When `Session.reduce_describe_query_enabled` is enabled, fewer DESCRIBE queries are issued when the outer query only projects or renames columns from an inner subquery whose column types are already known. + ## 1.50.0 (2026-04-23) ### Snowpark Python API Updates diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index 028dfdec72..f4ea001917 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -85,8 +85,8 @@ has_invalid_projection_merge_functions, ) from snowflake.snowpark._internal.utils import ( - is_sql_select_statement, ExprAliasUpdateDict, + is_sql_select_statement, ) import snowflake.snowpark.context as context @@ -1592,6 +1592,70 @@ def select(self, cols: List[Expression]) -> "SelectStatement": ) ) + # When describe reduction is on and the inner select already has resolved + # attributes, infer new.attributes for this outer select by reusing datatype and + # nullable from the subquery: (0) skip if parent column names collide, (1) index + # attributes by exact parent Attribute.name, (2) walk new.projection, (3) only + # handle plain columns or Alias(column), (4) resolve source by exact string match + # of the projection source name to that name (no quote_name / normalization), + # (5) assign only if every output column was inferred (length matches projection). 
+ if self._session.reduce_describe_query_enabled and self.attributes is not None: + parent_attributes = self.attributes + projection = new.projection + inferred_attributes: Optional[List[Attribute]] = None + # Skip: no projection to walk (do not assert; leave new.attributes unchanged). + if projection is not None: + # Skip: duplicate output names on the parent — dict/lookup would be ambiguous. + attributes_by_column_name: Dict[str, Attribute] = {} + collision = False + for attr in parent_attributes: + key = attr.name + existing = attributes_by_column_name.get(key) + # Skip: two parent columns share the same name string. + if existing is not None and existing is not attr: + collision = True + break + attributes_by_column_name[key] = attr + if not collision: + inferred_attributes = [] + for expr in projection: + source_column_name: Optional[str] = None + projected_column_name: Optional[str] = None + if isinstance(expr, (Attribute, UnresolvedAttribute)): + source_column_name = expr.name + projected_column_name = expr.name + elif isinstance(expr, Alias) and isinstance( + expr.child, (Attribute, UnresolvedAttribute) + ): + source_column_name = expr.child.name + projected_column_name = expr.name + else: + # Skip: not a plain column or Alias(Attribute|UnresolvedAttribute). + inferred_attributes = [] + break + + if source_column_name is None or projected_column_name is None: + # Skip: missing projected output name. + inferred_attributes = [] + break + source_attr = attributes_by_column_name.get(source_column_name) + # Skip: no parent column for this source name. + if source_attr is None: + inferred_attributes = [] + break + inferred_attributes.append( + Attribute( + projected_column_name, + source_attr.datatype, + source_attr.nullable, + ) + ) + if len(inferred_attributes) != len(projection): + # Skip: incomplete inference (includes defensive mismatch). 
+ inferred_attributes = None + if inferred_attributes is not None: + new.attributes = inferred_attributes + new.flatten_disabled = disable_next_level_flatten assert new.projection is not None new._column_states = derive_column_states_from_subquery( diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 09639160dc..bfa9e4cddc 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -259,7 +259,7 @@ def test_join_with_alias_dataframe(session): expected_describe_count = ( - 3 + 2 if (session.reduce_describe_query_enabled and session.sql_simplifier_enabled) else 4 ) diff --git a/tests/integ/test_reduce_describe_query.py b/tests/integ/test_reduce_describe_query.py index 5fe9e033b2..42842b3212 100644 --- a/tests/integ/test_reduce_describe_query.py +++ b/tests/integ/test_reduce_describe_query.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # -from typing import List +from typing import Dict, List import copy import pytest @@ -10,7 +10,7 @@ from unittest.mock import patch -from snowflake.snowpark import DataFrame +from snowflake.snowpark import DataFrame, Row from snowflake.snowpark._internal.analyzer.expression import Attribute, Star from snowflake.snowpark._internal.analyzer.unary_expression import UnresolvedAlias from snowflake.snowpark._internal.analyzer.unary_plan_node import Project @@ -19,6 +19,7 @@ TempObjectType, random_name_for_temp_object, ) +from snowflake.snowpark.exceptions import SnowparkPlanException from snowflake.snowpark.functions import ( avg, col, @@ -34,9 +35,9 @@ _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, Session, ) -from snowflake.snowpark.types import LongType, StringType, StructField, StructType from tests.integ.utils.sql_counter import SqlCounter -from tests.utils import IS_IN_STORED_PROC, TestData +from tests.utils import IS_IN_STORED_PROC, TestData, Utils
pytestmark = [ pytest.mark.skipif( @@ -228,6 +229,10 @@ def check_attributes_equality(attrs1: List[Attribute], attrs2: List[Attribute]) assert attr1.nullable == attr2.nullable +def _attrs_by_name(parent_attributes: List[Attribute]) -> Dict[str, Attribute]: + return {attr.name: attr for attr in parent_attributes} + + def has_star_in_projection(df: DataFrame) -> bool: plan = df._plan.source_plan return isinstance(plan, Project) and any( @@ -421,6 +426,273 @@ def test_cache_metadata_on_selectable_entity(session): _ = df.col("a") +def test_project_alias_infers_attributes_from_parent_metadata(session): + df = session.create_dataframe(["v"], schema=["c"]) + _ = df.schema + + parent_attributes = df._plan._metadata.attributes + assert parent_attributes is not None + expected_attributes = [parent_attributes[0].with_name("a2")] + + df2 = df.select(col("c").alias("a2")) + if session.reduce_describe_query_enabled: + check_attributes_equality(df2._plan._metadata.attributes, expected_attributes) + expected_describe_count = 0 + else: + assert df2._plan._metadata.attributes is None + expected_describe_count = 1 + + with SqlCounter(query_count=0, describe_count=expected_describe_count): + check_attributes_equality(df2._plan.attributes, expected_attributes) + + +def test_swap_column_aliases_infers_types_from_source_names(session): + """n-1: columns a, b; n: b AS a, a AS b — metadata follows source column types.""" + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + _ = df.schema + + parent_attributes = df._plan._metadata.attributes + assert parent_attributes is not None + by_name = _attrs_by_name(parent_attributes) + expected_attributes = [ + by_name['"B"'].with_name("a"), + by_name['"A"'].with_name("b"), + ] + + df2 = df.select(col("b").alias("a"), col("a").alias("b")) + Utils.check_answer(df2, [Row(A=2, B=1)], sort=False) + + if session.reduce_describe_query_enabled: + check_attributes_equality(df2._plan._metadata.attributes, expected_attributes) + 
expected_describe_count = 0 + else: + assert df2._plan._metadata.attributes is None + expected_describe_count = 1 + + with SqlCounter(query_count=0, describe_count=expected_describe_count): + check_attributes_equality(df2._plan.attributes, expected_attributes) + + +def test_swap_mixed_column_types_inference_follows_source(session): + """Swapped output names must take datatypes from the referenced source column.""" + schema = StructType( + [ + StructField("a", LongType(), nullable=True), + StructField("b", StringType(), nullable=True), + ] + ) + df = session.create_dataframe([(1, "x")], schema=schema) + _ = df.schema + + parent_attributes = df._plan._metadata.attributes + assert parent_attributes is not None + by_name = _attrs_by_name(parent_attributes) + expected_attributes = [ + by_name['"B"'].with_name("a"), + by_name['"A"'].with_name("b"), + ] + + df2 = df.select(col("b").alias("a"), col("a").alias("b")) + Utils.check_answer(df2, [Row(A="x", B=1)], sort=False) + + if session.reduce_describe_query_enabled: + check_attributes_equality(df2._plan._metadata.attributes, expected_attributes) + expected_describe_count = 0 + else: + assert df2._plan._metadata.attributes is None + expected_describe_count = 1 + + with SqlCounter(query_count=0, describe_count=expected_describe_count): + check_attributes_equality(df2._plan.attributes, expected_attributes) + + +def test_column_permutation_inference_name_keyed_lookup(session): + """Non-swap rename: c->a, a->x, b->y — each output type matches its source column.""" + df = session.create_dataframe([[1, 2, 3]], schema=["a", "b", "c"]) + _ = df.schema + + parent_attributes = df._plan._metadata.attributes + assert parent_attributes is not None + by_name = _attrs_by_name(parent_attributes) + expected_attributes = [ + by_name['"C"'].with_name("a"), + by_name['"A"'].with_name("x"), + by_name['"B"'].with_name("y"), + ] + + df2 = df.select(col("c").alias("a"), col("a").alias("x"), col("b").alias("y")) + Utils.check_answer(df2, [Row(A=3, 
X=1, Y=2)], sort=False) + + if session.reduce_describe_query_enabled: + check_attributes_equality(df2._plan._metadata.attributes, expected_attributes) + expected_describe_count = 0 + else: + assert df2._plan._metadata.attributes is None + expected_describe_count = 1 + + with SqlCounter(query_count=0, describe_count=expected_describe_count): + check_attributes_equality(df2._plan.attributes, expected_attributes) + + +def test_chained_simple_renames_infer_from_previous_metadata(session): + """Second select's parent already has inferred attributes from the first rename.""" + df = session.create_dataframe([[10, 20]], schema=["a", "b"]) + _ = df.schema + + df1 = df.select(col("a").alias("p"), col("b").alias("q")) + if session.reduce_describe_query_enabled: + assert df1._plan._metadata.attributes is not None + mid_attrs = df1._plan._metadata.attributes + assert mid_attrs is not None or not session.reduce_describe_query_enabled + + df2 = df1.select(col("p").alias("x"), col("q").alias("y")) + _ = df1.schema + + if session.reduce_describe_query_enabled: + assert df2._plan._metadata.attributes is not None + by_mid = _attrs_by_name(df1._plan._metadata.attributes or []) + expected_attributes = [ + by_mid['"P"'].with_name("x"), + by_mid['"Q"'].with_name("y"), + ] + check_attributes_equality(df2._plan._metadata.attributes, expected_attributes) + with SqlCounter(query_count=0, describe_count=0): + check_attributes_equality(df2._plan.attributes, expected_attributes) + else: + assert df2._plan._metadata.attributes is None + with SqlCounter(query_count=0, describe_count=1): + _ = df2._plan.attributes + + +def test_quoted_case_sensitive_sql_column_metadata_inference(session): + """Delimited identifier from session.sql: chained select infers metadata without DESCRIBE.""" + df = session.sql('SELECT 1 AS "MixedCase"') + with SqlCounter(query_count=0, describe_count=1, strict=False): + _ = df.schema + + df2 = df.select(col('"MixedCase"')) + if session.reduce_describe_query_enabled: + 
assert df2._plan._metadata.attributes is not None + assert len(df2._plan._metadata.attributes) == 1 + assert df2._plan._metadata.attributes[0].name == '"MixedCase"' + + expected_describe = 0 if session.reduce_describe_query_enabled else 1 + with SqlCounter(query_count=0, describe_count=expected_describe): + attrs = df2._plan.attributes + assert attrs is not None + assert len(attrs) == 1 + assert attrs[0].name == '"MixedCase"' + + +def test_non_simple_projection_skips_metadata_inference(session): + """Expressions other than plain column or simple alias(column) do not infer attributes.""" + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + _ = df.schema + + df2 = df.select((col("a") + lit(1)).alias("ap1"), "b") + + assert df2._plan._metadata.attributes is None + + with SqlCounter(query_count=0, describe_count=1): + _ = df2._plan.attributes + + df3 = df.select(col("a"), (col("b") + lit(1)).alias("b")) + assert df3._plan._metadata.attributes is None + with SqlCounter(query_count=0, describe_count=1): + _ = df3._plan.attributes + + +def test_mixed_simple_column_and_literal_alias_still_requires_describe(session): + """Alias(Literal) is not a simple rename; inference aborts even when the first column is plain.""" + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + _ = df.schema + + df2 = df.select("a", lit(1).alias("c")) + assert df2._plan._metadata.attributes is None + + with SqlCounter(query_count=0, describe_count=1): + _ = df2._plan.attributes + + +def test_simple_column_then_complex_expression_no_partial_metadata(session): + """First column is inferable but second is not; all-or-nothing — no partial cached attributes.""" + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + _ = df.schema + + df2 = df.select("a", (col("b") + lit(1)).alias("b2")) + assert df2._plan._metadata.attributes is None + + with SqlCounter(query_count=0, describe_count=1): + _ = df2._plan.attributes + + +def 
test_cast_on_column_alias_still_requires_describe(session): + """Alias(Cast(...)) is not Alias(Attribute); types cannot be copied from the subquery without DESCRIBE.""" + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + _ = df.schema + + df2 = df.select(col("a").cast(LongType()).alias("a")) + assert df2._plan._metadata.attributes is None + + with SqlCounter(query_count=0, describe_count=1): + _ = df2._plan.attributes + + +def test_select_inference_skips_on_duplicate_parent_keys_and_missing_alias_name( + session, +): + """SelectStatement.select: (1) duplicate parent output aliases — collision skips inference + on a follow-up select; DESCRIBE is not skipped when resolving schema for the duplicate-alias + frame. (2) Alias with missing output name — defensive inference abort.""" + df = session.create_dataframe([[1, 2, 3]], schema=["a", "b", "c"]) + _ = df.schema + dup = df.select((col("a") + 1).as_("b"), (col("c") + 1).as_("b")) + with SqlCounter(query_count=0, describe_count=1): + _ = dup.schema + + dup_outer = dup.select(lit(1).alias("x")) + _ = dup_outer._plan.attributes + + # Scenario B: hit missing-projected-name guard without DataFrame.resolve (which would + # quote_name(None) on the Alias). Call SelectStatement.select directly. 
+ df2 = session.create_dataframe([[1]], schema=["a"]) + _ = df2.schema + bad = col("a").alias("out") + object.__setattr__(bad._expression, "name", None) + inner = df2._select_statement + new_ss = inner.select([bad._named()]) + assert new_ss.attributes is None + + +def test_reduce_describe_inference_invalid_delimited_identifier_rejected(session): + """Malformed delimited identifiers are rejected by plan analysis (error 1200), not coerced.""" + df = session.create_dataframe([[1]], schema=["x"]) + _ = df.schema + for bad_col in (r'"col""', r'"ab"c"', r'""col"'): + with pytest.raises(SnowparkPlanException) as ex_info: + df.select(col(bad_col)).collect() + assert ex_info.value.error_code == "1200" + assert "Invalid identifier" in str(ex_info.value) + + +def test_select_star_after_cached_parent(session): + """SELECT * after parent schema is cached: infer_metadata can copy child attributes when reduce_describe is on.""" + df = session.create_dataframe([[1, 2]], schema=["a", "b"]) + _ = df.schema + parent_attrs = df._plan._metadata.attributes + assert parent_attrs is not None + + df2 = df.select("*") + if session.reduce_describe_query_enabled: + assert df2._plan._metadata.attributes is not None + check_attributes_equality(df2._plan._metadata.attributes, parent_attrs) + else: + assert df2._plan._metadata.attributes is None + + # Resolving attributes must match the logical schema (DESCRIBE may run when reduce is off). + check_attributes_equality(df2._plan.attributes, parent_attrs) + + @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Can't create a session in SP") def test_reduce_describe_query_enabled_on_session(db_parameters): with Session.builder.configs(db_parameters).create() as new_session: