diff --git a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py index fd0950dbd8..bd718529a5 100644 --- a/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/metadata_utils.py @@ -1,10 +1,11 @@ # # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +from collections import Counter from enum import Enum from dataclasses import dataclass from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union -from snowflake.snowpark.types import DataType +from snowflake.snowpark.types import ArrayType, DataType, MapType, StructType from snowflake.snowpark._internal.analyzer.expression import ( Attribute, @@ -13,7 +14,7 @@ Star, ) from snowflake.snowpark._internal.analyzer.datatype_mapper import to_sql -from snowflake.snowpark._internal.analyzer.unary_expression import Alias +from snowflake.snowpark._internal.analyzer.unary_expression import Alias, Cast from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( Limit, LogicalPlan, @@ -86,39 +87,111 @@ def infer_quoted_identifiers_from_expressions( def _extract_inferable_attribute_names( attributes: Optional[List[Expression]], + from_attributes: Optional[List[Attribute]] = None, ) -> tuple[Optional[List[Attribute]], Optional[List[Attribute]]]: """ - Returns a list of attribute names that can be infered from a list of Expressions. - Returns None if one or more attributes cannot be infered. + Returns a tuple of (expected_attributes, resolved_attributes) that can be + inferred from a list of Expressions. + - expected_attributes: Attributes that were already resolved (direct column refs) + - resolved_attributes: All attributes in projection order, with types resolved + Returns (None, None) if one or more attributes cannot be inferred. """ if attributes is None: return None, None - new_attributes = [] - old_attributes = [] + # For Alias(Attribute(...), ...), we copy the parent's type from from_attributes by + # quoted name. That is only defined when that name appears once in FROM; if it appears + # more than once (e.g. table-function column overlap), name lookup is ambiguous — in + # that case we fail inference only when we actually need that lookup (see below). + name_counts: Counter[str] = Counter() + from_attr_map: Dict[str, Attribute] = {} + if from_attributes: + name_counts = Counter(a.name for a in from_attributes) + from_attr_map = {a.name: a for a in from_attributes if name_counts[a.name] == 1} + + expected_attributes = [] + resolved_in_order = [] for attr in attributes: - # Attributes are already resolved and don't require inferrence if isinstance(attr, Attribute): - old_attributes.append(attr) + expected_attributes.append(attr) + resolved_in_order.append(attr) continue if isinstance(attr, Alias): # If the first non-aliased child of an Alias node is Literal or Attribute # the column can be inferred. - if isinstance(attr.child, (Literal, Attribute)) and attr.datatype: - attr = Attribute(attr.name, attr.datatype, attr.nullable) + if isinstance(attr.child, Attribute): + # Resolve type from from_attributes by matching child name. + if name_counts[attr.child.name] > 1: + return None, None + if attr.child.name in from_attr_map: + parent = from_attr_map[attr.child.name] + attr = Attribute(attr.name, parent.datatype, parent.nullable) + elif from_attributes is not None: + # FROM schema is known but doesn't contain this column name. + # This shouldn't happen in normal operation — bail out rather + # than proceed with a potentially wrong placeholder type. + return None, None + elif attr.datatype: + attr = Attribute(attr.name, attr.datatype, attr.nullable) + elif isinstance(attr.child, Literal): + if attr.datatype: + attr = Attribute(attr.name, attr.datatype, attr.nullable) + elif isinstance(attr.child, Cast): + # Use the Cast's declared target type for scalar types. + # Structured types (ArrayType, MapType, StructType) are excluded + # because Snowflake may promote their element types at execution + # time (e.g., ArrayType(IntegerType()) -> ArrayType(LongType())), + # so Cast.to may not match the server-returned type. For those, + # we fall through and let a describe query get the actual type. + if type(attr.child.to) is not DataType and not isinstance( + attr.child.to, (ArrayType, MapType, StructType) + ): + attr = Attribute(attr.name, attr.child.to, attr.nullable) elif isinstance(attr, Literal) and type(attr.datatype) != DataType: # Names of literal values can be inferred attr = Attribute( to_sql(attr.value, attr.datatype), attr.datatype, attr.nullable ) - # If the attr has been coerced to attribute then it has been inferred + # If the attr has been coerced to attribute then it has been inferred. if isinstance(attr, Attribute): - new_attributes.append(attr) + resolved_in_order.append(attr) else: return None, None - return old_attributes, new_attributes + # Every item in attributes was resolved; len(resolved_in_order) == len(attributes). + return expected_attributes, resolved_in_order + + +def try_infer_attributes_from_flattened_projection( + final_projection: Optional[List[Expression]], + parent_attributes: Optional[List[Attribute]], +) -> Optional[List[Attribute]]: + """Re-derive attributes after a select() flattening, using parent types. + + When CTE optimization is enabled the ``SelectStatement.select()`` flattening + path must invalidate cached ``_attributes`` because the projection changed. + Rather than unconditionally setting ``_attributes = None`` (which forces a + DESCRIBE query), this helper tries to resolve column types from the parent's + already-known attributes via ``_extract_inferable_attribute_names``. + Only succeeds when every column in the new projection can be fully resolved + to a concrete type (Alias with a resolved Attribute child, or Literal). + Returns ``None`` whenever any column cannot be resolved -- preserving the + existing describe-query behaviour as a safe default. + """ + if final_projection is None or parent_attributes is None: + return None + + _, resolved = _extract_inferable_attribute_names( + final_projection, parent_attributes + ) + + if resolved is not None and all( + isinstance(a, Attribute) and type(a.datatype) is not DataType for a in resolved + ): + return resolved + + return None def _extract_selectable_attributes( @@ -142,28 +215,23 @@ def _extract_selectable_attributes( else: # Get the attributes from the child plan from_attributes = _extract_selectable_attributes(current_plan.from_) - ( - expected_attributes, - new_attributes, - # Extract expected attributes and knowable new attributes - # from current plan - ) = _extract_inferable_attribute_names(current_plan.projection) + (expected_attributes, resolved,) = _extract_inferable_attribute_names( + current_plan.projection, from_attributes + ) # Check that the expected attributes match the attributes from the child plan if ( from_attributes is not None and expected_attributes is not None - and new_attributes is not None + and resolved is not None ): missing_attrs = {attr.name for attr in expected_attributes} - { attr.name for attr in from_attributes } if not missing_attrs and all( - isinstance(attr, (Attribute, Alias)) - # If the attribute datatype is specifically DataType then it is not fully resolved - and type(attr.datatype) is not DataType - for attr in current_plan.projection or [] + isinstance(attr, Attribute) and type(attr.datatype) is not DataType + for attr in resolved ): - attributes = current_plan.projection # type: ignore + attributes = resolved elif ( isinstance( current_plan, diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index f2f8f4a4b2..71368c59d7 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -1460,11 +1460,6 @@ def select(self, cols: List[Expression]) -> "SelectStatement": if can_be_flattened: new = copy(self) final_projection = [] - if ( - self._session.reduce_describe_query_enabled - and self._session.cte_optimization_enabled - ): - new._attributes = None # reset attributes since projection changed assert new_column_states is not None for col, state in new_column_states.items(): if state.change_state in ( @@ -1477,6 +1472,18 @@ def select(self, cols: List[Expression]) -> "SelectStatement": copy(self.column_states[col].expression) ) # add subquery's expression for this column name + if ( + self._session.reduce_describe_query_enabled + and self._session.cte_optimization_enabled + ): + from snowflake.snowpark._internal.analyzer.metadata_utils import ( + try_infer_attributes_from_flattened_projection, + ) + + new._attributes = try_infer_attributes_from_flattened_projection( + final_projection, + self._attributes, + ) new.projection = final_projection new.from_ = self.from_.to_subqueryable() new.pre_actions = new.from_.pre_actions diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 09639160dc..261706d3b1 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -259,7 +259,7 @@ def test_binary(session, type, action): def test_join_with_alias_dataframe(session): expected_describe_count = ( - 3 + 1 if (session.reduce_describe_query_enabled and session.sql_simplifier_enabled) else 4 ) diff --git a/tests/integ/test_eager_schema_validation.py b/tests/integ/test_eager_schema_validation.py index 6f7a18541e..835fc89634 100644 --- a/tests/integ/test_eager_schema_validation.py +++ b/tests/integ/test_eager_schema_validation.py @@ -18,41 +18,39 @@ ) @pytest.mark.parametrize("debug_mode", [True, False]) @pytest.mark.parametrize( - "transform", + "transform, infers_without_debug", [ - pytest.param(lambda x: copy(x), id="copy"), - pytest.param(lambda x: x.to_df(["C", "D"]), id="to_df"), - pytest.param(lambda x: x.distinct(), id="distinct"), - pytest.param(lambda x: x.drop_duplicates(), id="drop_duplicates"), - pytest.param(lambda x: x.limit(1), id="limit"), - pytest.param(lambda x: x.union(x), id="union"), - pytest.param(lambda x: x.union_all(x), id="union_all"), - pytest.param(lambda x: x.union_by_name(x), id="union_by_name"), - pytest.param(lambda x: x.union_all_by_name(x), id="union_all_by_name"), - pytest.param(lambda x: x.intersect(x), id="intersect"), - pytest.param(lambda x: x.natural_join(x), id="natural_join"), - pytest.param(lambda x: x.cross_join(x), id="cross_join"), - pytest.param(lambda x: x.sample(n=1), id="sample"), + pytest.param(lambda x: copy(x), False, id="copy"), + pytest.param(lambda x: x.to_df(["C", "D"]), True, id="to_df"), + pytest.param(lambda x: x.distinct(), False, id="distinct"), + pytest.param(lambda x: x.drop_duplicates(), False, id="drop_duplicates"), + pytest.param(lambda x: x.limit(1), False, id="limit"), + pytest.param(lambda x: x.union(x), False, id="union"), + pytest.param(lambda x: x.union_all(x), False, id="union_all"), + pytest.param(lambda x: x.union_by_name(x), False, id="union_by_name"), + pytest.param(lambda x: x.union_all_by_name(x), False, id="union_all_by_name"), + pytest.param(lambda x: x.intersect(x), False, id="intersect"), + pytest.param(lambda x: x.natural_join(x), False, id="natural_join"), + pytest.param(lambda x: x.cross_join(x), False, id="cross_join"), + pytest.param(lambda x: x.sample(n=1), False, id="sample"), pytest.param( - lambda x: x.with_column_renamed(col("A"), "B"), id="with_column_renamed" + lambda x: x.with_column_renamed(col("A"), "B"), + False, + id="with_column_renamed", ), - # Unpivot already validates names - pytest.param(lambda x: x.unpivot("x", "y", ["A"]), id="unpivot"), - # The following functions do not error early because their schema_query do not contain - # information about the transformation being called. - pytest.param(lambda x: x.drop(col("A")), id="drop"), - pytest.param(lambda x: x.filter(col("A") == lit(1)), id="filter"), - pytest.param(lambda x: x.sort(col("A").desc()), id="sort"), + pytest.param(lambda x: x.unpivot("x", "y", ["A"]), False, id="unpivot"), + pytest.param(lambda x: x.drop(col("A")), False, id="drop"), + pytest.param(lambda x: x.filter(col("A") == lit(1)), False, id="filter"), + pytest.param(lambda x: x.sort(col("A").desc()), False, id="sort"), ], ) -def test_early_attributes(session, transform, debug_mode): +def test_early_attributes(session, transform, infers_without_debug, debug_mode): with patch.object(context, "_debug_eager_schema_validation", debug_mode): df = session.create_dataframe([(1, "A"), (2, "B"), (3, "C")], ["A", "B"]) transformed = transform(df) - # When debug mode is enabled the dataframe plan attributes are populated early - if debug_mode: + if debug_mode or infers_without_debug: assert transformed._plan._metadata.attributes is not None else: assert transformed._plan._metadata.attributes is None diff --git a/tests/integ/test_reduce_describe_query.py b/tests/integ/test_reduce_describe_query.py index 5fe9e033b2..c050214521 100644 --- a/tests/integ/test_reduce_describe_query.py +++ b/tests/integ/test_reduce_describe_query.py @@ -34,7 +34,7 @@ _PYTHON_SNOWPARK_REDUCE_DESCRIBE_QUERY_ENABLED, Session, ) -from snowflake.snowpark.types import LongType, StructField, StructType +from snowflake.snowpark.types import LongType, StringType, StructField, StructType from tests.integ.utils.sql_counter import SqlCounter from tests.utils import IS_IN_STORED_PROC, TestData @@ -493,3 +493,156 @@ def test_update_schema_query_when_attributes_available(session): ) if should_simplify: assert df._plan.schema_query == simplified_schema_query2 + + +def test_infer_cast_type_reduces_describe_for_self_join(session): + """Verify that Alias(Cast) inference reduces describe queries for self-joins + with typed schemas, and that results are correct. + + With reduce_describe_query_enabled=True, the Cast inference avoids redundant + describes for repeated references to the same table (self-join). Without it, + each reference triggers its own describe. + """ + # Without optimization: 5 describes for create + self-join + select + collect + # With optimization: 1 describe (only the initial schema resolution) + expected_describe_count = 1 if session.reduce_describe_query_enabled else 5 + with SqlCounter( + query_count=1, describe_count=expected_describe_count, join_count=1 + ): + df = session.create_dataframe( + [[1, 2], [3, 4]], + schema=StructType( + [StructField("A", LongType()), StructField("B", LongType())] + ), + ) + joined = ( + df.alias("L") + .join(df.alias("R"), col("L", "A") == col("R", "A")) + .select(col("L", "A"), col("R", "B")) + ) + result = joined.collect() + + assert len(result) > 0 + f = joined.schema.fields + assert len(f) == 2 + assert f[0] == StructField("AL", LongType(), nullable=True) + assert f[1] == StructField("BL", LongType(), nullable=True) + + +def test_infer_cast_type_correctness_across_projections(session): + """Verify that Cast/Alias inference produces correct types across + various projection patterns with CTE enabled, exercising the + try_infer_attributes_from_flattened_projection code path. + Each sub-case checks both result correctness and schema types.""" + with patch.object(session, "_cte_optimization_enabled", True): + df = session.create_dataframe( + [[1, "a"], [2, "b"]], + schema=StructType( + [StructField("A", LongType()), StructField("NAME", StringType())] + ), + ) + df2 = session.create_dataframe( + [[3, "c"]], + schema=StructType( + [StructField("A", LongType()), StructField("NAME", StringType())] + ), + ) + + # 1. cast + union: cast changes type to StringType, union preserves it. + cu = df.select(col("A").cast(StringType()).alias("A"), col("NAME")).union_all( + df2.select(col("A").cast(StringType()).alias("A"), col("NAME")) + ) + assert len(cu.collect()) == 3 + assert cu.schema == StructType( + [ + StructField("A", StringType(), nullable=True), + StructField("NAME", StringType(), nullable=True), + ] + ) + + # 2. cast + self-join: Cast inference preserves StringType through join. + casted = df.select(col("A").cast(StringType()).alias("A_STR"), col("NAME")) + cj = casted.alias("L").join( + casted.alias("R"), col("L", "A_STR") == col("R", "A_STR") + ) + assert len(cj.collect()) > 0 + f = cj.schema.fields + assert f[0] == StructField("A_STRL", StringType(), nullable=True) + assert f[1] == StructField("NAMEL", StringType(), nullable=True) + assert f[2] == StructField("A_STRR", StringType(), nullable=True) + assert f[3] == StructField("NAMER", StringType(), nullable=True) + + # 3. alias col A (LongType) to "NAME" then join with original NAME (StringType). + # The rename must carry A's type (LongType), not original NAME's type. + renamed = df.select(col("A").alias("NAME")) + rj = renamed.alias("L").join( + df.alias("R"), col("L", "NAME") == col("R", "NAME") + ) + assert len(rj.collect()) > 0 + f = rj.schema.fields + assert len(f) == 3 + assert f[0] == StructField("NAMEL", LongType(), nullable=True) + assert f[1] == StructField("A", LongType(), nullable=True) + assert f[2] == StructField("NAMER", StringType(), nullable=True) + + +def test_infer_cast_type_reduces_describe_sas_style_self_join(session): + """Mimic SAS (Snowpark Connect) internal mechanism for self-joins: + + SAS translates SQL like "FROM users e JOIN users m ON ..." by: + 1. session.create_dataframe(..., schema=StructType) -> temp table with Cast projections + 2. session.table(temp_view_name) -> two independent SelectableEntity references + 3. df.select(col(orig).cast(type).alias(plan_id_name)) -> renames with plan-specific IDs + 4. left.join(right, condition).select(...) -> join + projection + + This test reproduces that pattern in pure Snowpark to verify describe reduction. + """ + with patch.object(session, "_cte_optimization_enabled", True): + # Step 1: create a temp table (simulates createDataFrame + createOrReplaceTempView) + base = session.create_dataframe( + [[1, "alice", None], [2, "bob", 1], [3, "charlie", 1]], + schema=StructType( + [ + StructField("USER_ID", LongType()), + StructField("NAME", StringType()), + StructField("MANAGER_ID", LongType()), + ] + ), + ) + temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) + base.write.save_as_table(temp_table_name, mode="overwrite", table_type="temp") + + # Steps 2-4 wrapped in SqlCounter to measure describes during plan building. + # SAS-style: table() -> select(cast.alias) -> join -> select -> collect + # With optimization: 2 describes (vs 5 without) -- Cast inference avoids 3 + expected_describe = 2 if session.reduce_describe_query_enabled else 5 + with SqlCounter(query_count=1, describe_count=expected_describe, join_count=1): + left_ref = session.table(temp_table_name) + right_ref = session.table(temp_table_name) + + left = left_ref.select( + col("USER_ID").cast(LongType()).alias("USER_ID_L"), + col("NAME").cast(StringType()).alias("NAME_L"), + col("MANAGER_ID").cast(LongType()).alias("MANAGER_ID_L"), + ) + right = right_ref.select( + col("USER_ID").cast(LongType()).alias("USER_ID_R"), + col("NAME").cast(StringType()).alias("NAME_R"), + col("MANAGER_ID").cast(LongType()).alias("MANAGER_ID_R"), + ) + + joined = left.join(right, col("MANAGER_ID_L") == col("USER_ID_R")).select( + col("USER_ID_L"), + col("NAME_L"), + col("USER_ID_R"), + col("NAME_R"), + ) + result = joined.collect() + + assert len(result) == 2 + f = joined.schema.fields + assert len(f) == 4 + assert f[0] == StructField("USER_ID_L", LongType(), nullable=True) + assert f[1] == StructField("NAME_L", StringType(), nullable=True) + assert f[2] == StructField("USER_ID_R", LongType(), nullable=True) + assert f[3] == StructField("NAME_R", StringType(), nullable=True) diff --git a/tests/unit/test_metadata_utils.py b/tests/unit/test_metadata_utils.py new file mode 100644 index 0000000000..981ea74c09 --- /dev/null +++ b/tests/unit/test_metadata_utils.py @@ -0,0 +1,204 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +import pytest + +from snowflake.snowpark._internal.analyzer.expression import ( + Attribute, + Expression, + Literal, +) +from snowflake.snowpark._internal.analyzer.metadata_utils import ( + _extract_inferable_attribute_names, +) +from snowflake.snowpark._internal.analyzer.unary_expression import Alias, Cast +from snowflake.snowpark.types import ( + ArrayType, + DataType, + DoubleType, + IntegerType, + MapType, + StringType, + StructType, +) + + +def test_none_input(): + assert _extract_inferable_attribute_names(None) == (None, None) + + +def test_plain_attribute_passthrough(): + attrs = [Attribute('"A"', IntegerType()), Attribute('"B"', StringType())] + expected, resolved = _extract_inferable_attribute_names(attrs) + assert len(expected) == 2 + assert len(resolved) == 2 + assert resolved[0].name == '"A"' + assert resolved[0].datatype == IntegerType() + assert resolved[1].name == '"B"' + assert resolved[1].datatype == StringType() + + +def test_alias_attribute_resolved_from_parent(): + """Alias(Attribute) resolves type from from_attributes by name.""" + child = Attribute('"A"', DataType()) + projection = [Alias(child, '"X"')] + from_attributes = [Attribute('"A"', IntegerType(), nullable=False)] + + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"X"' + assert resolved[0].datatype == IntegerType() + + +def test_alias_attribute_no_match_in_parent(): + """Alias(Attribute) references a name not present in from_attributes — inference fails.""" + child = Attribute('"MISSING"', DataType()) + projection = [Alias(child, '"X"')] + from_attributes = [Attribute('"A"', IntegerType())] + + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert (expected, resolved) == (None, None) + + +def test_alias_attribute_no_from_attributes_uses_own_datatype(): + """Without from_attributes, Alias(Attribute) resolves via attr.datatype (main compat). + + The resolved Attribute carries a placeholder DataType(); downstream callers + reject it via ``type(attr.datatype) is not DataType``, so no incorrect + metadata is cached — but the function itself does not return (None, None).""" + child = Attribute('"MISSING"', DataType()) + projection = [Alias(child, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"X"' + assert type(resolved[0].datatype) is DataType + + +def test_alias_cast_scalar_type(): + """Alias(Cast(to=scalar_type)) resolves to the Cast target type.""" + inner = Attribute('"A"', DataType()) + cast = Cast(inner, IntegerType()) + projection = [Alias(cast, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"X"' + assert resolved[0].datatype == IntegerType() + + +@pytest.mark.parametrize( + "structured_type", + [ + ArrayType(IntegerType()), + MapType(StringType(), IntegerType()), + StructType(), + ], + ids=["ArrayType", "MapType", "StructType"], +) +def test_alias_cast_structured_type_skipped(structured_type): + """Alias(Cast(to=structured_type)) is not inferred -- falls back to describe.""" + inner = Attribute('"A"', DataType()) + cast = Cast(inner, structured_type) + projection = [Alias(cast, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_alias_cast_dummy_datatype_skipped(): + """Alias(Cast(to=DataType())) is not inferred -- base DataType is not concrete.""" + inner = Attribute('"A"', DataType()) + cast = Cast(inner, DataType()) + projection = [Alias(cast, '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_alias_literal_with_known_type(): + """Alias(Literal) with concrete type resolves (existing behavior).""" + lit = Literal(42, IntegerType()) + projection = [Alias(lit, '"X"')] + projection[0].datatype = IntegerType() + + expected, resolved = _extract_inferable_attribute_names(projection) + assert resolved is not None + assert resolved[0].name == '"X"' + assert resolved[0].datatype == IntegerType() + + +def test_unresolvable_expression(): + """Alias with unknown child type -> (None, None).""" + + class UnknownExpr(Expression): + pass + + projection = [Alias(UnknownExpr(), '"X"')] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_mixed_resolvable_and_unresolvable(): + """If any expression is unresolvable, entire result is (None, None).""" + good = Attribute('"A"', IntegerType()) + + class UnknownExpr(Expression): + pass + + bad = Alias(UnknownExpr(), '"B"') + projection = [good, bad] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert (expected, resolved) == (None, None) + + +def test_multiple_cast_types(): + """Multiple Alias(Cast) with different scalar types all resolve.""" + projection = [ + Alias(Cast(Attribute('"A"', DataType()), IntegerType()), '"X"'), + Alias(Cast(Attribute('"B"', DataType()), StringType()), '"Y"'), + Alias(Cast(Attribute('"C"', DataType()), DoubleType()), '"Z"'), + ] + + expected, resolved = _extract_inferable_attribute_names(projection) + assert resolved is not None + assert len(resolved) == 3 + assert resolved[0].datatype == IntegerType() + assert resolved[1].datatype == StringType() + assert resolved[2].datatype == DoubleType() + + +def test_duplicate_from_attributes_plain_projection_still_resolves(): + """Duplicate FROM names do not block inference when projection only uses plain + Attributes (no ambiguous name-based parent lookup for Alias).""" + projection = [Attribute('"A"', IntegerType())] + from_attributes = [ + Attribute('"X"', IntegerType()), + Attribute('"X"', StringType()), + Attribute('"A"', StringType()), + ] + + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert resolved is not None + assert len(resolved) == 1 + assert resolved[0].name == '"A"' + + +def test_duplicate_from_attributes_alias_child_ambiguous_returns_none(): + """Alias(Attribute) cannot inherit type by name when that name is duplicated in FROM.""" + projection = [ + Alias(Attribute('"X"', DataType()), '"Y"'), + ] + from_attributes = [ + Attribute('"X"', IntegerType()), + Attribute('"X"', StringType()), + ] + + expected, resolved = _extract_inferable_attribute_names(projection, from_attributes) + assert (expected, resolved) == (None, None)