diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index e8287140a7..1c32af1f83 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -2,7 +2,8 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # -from collections import defaultdict +import hashlib +from collections import Counter, defaultdict from typing import Dict, List, Optional, Set from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan @@ -117,18 +118,30 @@ def _replace_duplicate_node_with_cte( Replace all duplicated nodes with a WithQueryBlock (CTE node), to enable query generation with CTEs. + In connect-compatible mode, per-object occurrence counts are built + during the traversal so that only nodes whose specific Python object + appears more than once are CTE-replaced. Different objects that + share the same ``encoded_node_id_with_query`` but each appear only + once are left as inline subqueries. + NOTE, we use stack to perform a post-order traversal instead of recursive call. The reason of using the stack approach is that chained CTEs have to be built from bottom (innermost subquery) to top (outermost query). This function uses an iterative approach to avoid hitting Python's maximum recursion depth limit. """ + use_per_object = context._is_snowpark_connect_compatible_mode node_parents_map: Dict[TreeNode, Set[TreeNode]] = defaultdict(set) + # Per-object occurrence counts, built during the traversal below. + # Only populated in connect-compatible mode. 
+ object_id_counts: Counter = Counter() stack1, stack2 = [root], [] while stack1: node = stack1.pop() stack2.append(node) + if use_per_object: + object_id_counts[id(node)] += 1 for child in reversed(node.children_plan_nodes): node_parents_map[child].add(node) stack1.append(child) @@ -136,7 +149,8 @@ def _replace_duplicate_node_with_cte( # track node that is already visited to avoid repeated operation on the same node visited_nodes: Set[TreeNode] = set() updated_nodes: Set[TreeNode] = set() - # track the resolved WithQueryBlock node has been created for each duplicated node + # track the resolved WithQueryBlock node has been created for each duplicated node. + # In per-object mode the key is str(id(node)); otherwise it is encoded_node_id_with_query. resolved_with_block_map: Dict[str, SnowflakePlan] = {} def _update_parents( @@ -162,30 +176,44 @@ def _update_parents( # if the node is a duplicated node and deduplication is not done for the node, # start the deduplication transformation use CTE if node.encoded_node_id_with_query in duplicated_node_ids: - if node.encoded_node_id_with_query in resolved_with_block_map: - # if the corresponding CTE block has been created, use the existing - # one. - resolved_with_block = resolved_with_block_map[ - node.encoded_node_id_with_query - ] + if use_per_object and object_id_counts[id(node)] <= 1: + visited_nodes.add(node) + if node in updated_nodes: + _update_parents(node, should_replace_child=False) + continue + + map_key = ( + node.encoded_node_id_with_query + if not use_per_object + else str(id(node)) + ) + if map_key in resolved_with_block_map: + resolved_with_block = resolved_with_block_map[map_key] else: if ( self._query_generator.session.reduce_describe_query_enabled and context._is_snowpark_connect_compatible_mode ): - # create a deterministic name using the first 10 chars of encoded_node_id_with_query (SHA256 hash) - # It helps when DataFrame.queries is called multiple times. 
- # Consistent CTE names returned, reducing the number of describe queries from cached_analyze_attributes calls. - cte_name = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{node.encoded_node_id_with_query[:HASH_LENGTH].upper()}" + if use_per_object: + obj_hash = ( + hashlib.sha256( + f"{node.encoded_node_id_with_query}:{id(node)}".encode() + ) + .hexdigest()[:HASH_LENGTH] + .upper() + ) + else: + obj_hash = node.encoded_node_id_with_query[ + :HASH_LENGTH + ].upper() + cte_name = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{obj_hash}" else: cte_name = random_name_for_temp_object(TempObjectType.CTE) with_block = WithQueryBlock(name=cte_name, child=node) # type: ignore with_block._is_valid_for_replacement = True resolved_with_block = self._query_generator.resolve(with_block) - resolved_with_block_map[ - node.encoded_node_id_with_query - ] = resolved_with_block + resolved_with_block_map[map_key] = resolved_with_block self._total_number_ctes += 1 _update_parents( node, should_replace_child=True, new_child=resolved_with_block diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 1adcbe3558..b82f53d04c 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -5,6 +5,7 @@ import re import tracemalloc from unittest import mock +import uuid import pytest @@ -60,8 +61,17 @@ WITH = "WITH" +@pytest.fixture(params=[False, True], ids=["connect_mode_off", "connect_mode_on"]) +def is_connect_mode(request): + """Parametrize every test over _is_snowpark_connect_compatible_mode.""" + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", request.param + ): + yield request.param + + @pytest.fixture(autouse=True) -def setup(request, session): +def setup(request, session, is_connect_mode): is_cte_optimization_enabled = session._cte_optimization_enabled is_query_compilation_enabled = session._query_compilation_stage_enabled session._query_compilation_stage_enabled = True @@ -251,7 +261,9 @@ def test_binary(session, type, 
action): assert len(plan_queries["post_actions"]) == 1 -def test_join_with_alias_dataframe(session): +def test_join_with_alias_dataframe(session, is_connect_mode): + c1 = f"col1_{uuid.uuid4().hex[:8]}" + c2 = f"col2_{uuid.uuid4().hex[:8]}" expected_describe_count = ( 3 if (session.reduce_describe_query_enabled and session.sql_simplifier_enabled) @@ -260,11 +272,11 @@ def test_join_with_alias_dataframe(session): with SqlCounter( query_count=2, describe_count=expected_describe_count, join_count=2 ): - df1 = session.create_dataframe([[1, 6]], schema=["col1", "col2"]) + df1 = session.create_dataframe([[1, 6]], schema=[c1, c2]) df_res = ( df1.alias("L") - .join(df1.alias("R"), col("L", "col1") == col("R", "col1")) - .select(col("L", "col1"), col("R", "col2")) + .join(df1.alias("R"), col("L", c1) == col("R", c1)) + .select(col("L", c1), col("R", c2)) ) session._cte_optimization_enabled = False @@ -355,7 +367,7 @@ def test_join_with_set_operation(session): @pytest.mark.parametrize("type, action", binary_operations) -def test_variable_binding_binary(session, type, action): +def test_variable_binding_binary(session, type, action, is_connect_mode): df1 = session.sql( "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "a", 2, "b"] ) @@ -372,10 +384,12 @@ def test_variable_binding_binary(session, type, action): join_count = 1 if type == "union": union_count = 1 + # df1 and df3 are different Python objects with the same SQL. + # In connect mode they should NOT be deduplicated. 
check_result( session, action(df1, df3), - expect_cte_optimized=True, + expect_cte_optimized=not is_connect_mode, query_count=1, describe_count=0, union_count=union_count, @@ -551,23 +565,85 @@ def test_number_of_ctes(session, type, action): ) -def test_different_df_same_query(session): +def test_different_df_same_query(session, is_connect_mode): df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") df = df2.union_all(df1) + # df1 and df2 are different Python objects with the same SQL. + # In connect mode they should NOT be deduplicated. check_result( session, df, - expect_cte_optimized=True, + expect_cte_optimized=not is_connect_mode, query_count=1, describe_count=0, union_count=1, join_count=0, ) + with SqlCounter(query_count=0, describe_count=0): + expected_cte_count = 0 if is_connect_mode else 1 + assert count_number_of_ctes(df.queries["queries"][-1]) == expected_cte_count + + +def test_mixed_duplicated_and_unique_objects_same_sql(session, is_connect_mode): + """ + union(union(df1, df1), union(df2, df3)) where df1, df2, df3 all produce + identical SQL but are different Python objects. 
+ + In connect mode: + - df1 appears twice (same object) -> should be CTE-deduplicated + - df2 and df3 each appear once -> should NOT be CTE-deduplicated + - Expect 1 CTE (for df1 only) + In non-connect mode: + - All share the same encoded ID -> treated as one CTE + - Expect 1 CTE (all unified) + """ + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df3 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df = df1.union_all(df1).union_all(df2.union_all(df3)) + check_result( + session, + df, + expect_cte_optimized=True, + query_count=1, + describe_count=0, + union_count=3, + join_count=0, + ) with SqlCounter(query_count=0, describe_count=0): assert count_number_of_ctes(df.queries["queries"][-1]) == 1 +def test_distinct_objects_each_duplicated(session, is_connect_mode): + """ + union(union(df1, df1), union(df2, df2)) where df1 and df2 produce + identical SQL but are different Python objects. 
+ + In connect mode: + - df1 appears twice, df2 appears twice -> each gets its own CTE + - Expect 2 CTEs + In non-connect mode: + - All share the same encoded ID -> one CTE + - Expect 1 CTE + """ + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df = df1.union_all(df1).union_all(df2.union_all(df2)) + check_result( + session, + df, + expect_cte_optimized=True, + query_count=1, + describe_count=0, + union_count=3, + join_count=0, + ) + with SqlCounter(query_count=0, describe_count=0): + expected_cte_count = 2 if is_connect_mode else 1 + assert count_number_of_ctes(df.queries["queries"][-1]) == expected_cte_count + + def test_same_duplicate_subtree(session): """ root @@ -624,7 +700,7 @@ def test_same_duplicate_subtree(session): @pytest.mark.parametrize("use_different_df", [True, False]) -def test_cte_preserves_join_suffix_aliases(session, use_different_df): +def test_cte_preserves_join_suffix_aliases(session, use_different_df, is_connect_mode): df_ad_group = session.create_dataframe( [["1048771", "group_1", "campaign_1"]], schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"], @@ -695,8 +771,14 @@ def test_cte_preserves_join_suffix_aliases(session, use_different_df): assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql # when using different df_ad_group with disambiguation, because rsuffix in join, # they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE - # However there is still a CTE for create_dataframe call - assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1 + # However there is still a CTE for create_dataframe call. + # In connect mode with use_different_df, all create_dataframe calls are + # distinct objects so no CTEs are produced. 
+ if is_connect_mode and use_different_df: + expected_cte_count = 0 + else: + expected_cte_count = 1 + assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == expected_cte_count @pytest.mark.parametrize( @@ -807,7 +889,7 @@ def test_explain(session): assert "WITH SNOWPARK_TEMP_CTE" in explain_string -def test_sql_simplifier(session): +def test_sql_simplifier(session, is_connect_mode): if not session._sql_simplifier_enabled: pytest.skip("SQL simplifier is not enabled") @@ -822,6 +904,9 @@ def test_sql_simplifier(session): df2 = df1.select("a", "b") df3 = df1.select("a", "b").select("a", "b") df4 = df1.union_by_name(df2).union_by_name(df3) + # df1, df2, df3 are different Python objects that simplify to the same SQL. + # In connect mode they are not deduplicated, but df (create_dataframe) is + # still the same object appearing across all branches → still CTE'd. check_result( session, df4, @@ -832,11 +917,35 @@ def test_sql_simplifier(session): join_count=0, ) with SqlCounter(query_count=0, describe_count=0): - # after applying sql simplifier, there is only one CTE (df1, df2, df3 have the same query) - assert ( - count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1])) == 1 - ) - assert Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause) == 1 + if is_connect_mode: + # df1, df2, df3 are different objects → not merged. + # Only df (create_dataframe) is the same object across all branches → 1 CTE. + # Generated SQL: + # WITH CTE AS (SELECT $1 AS "A", $2 AS "B" FROM VALUES ...) + # (SELECT "A","B" FROM (CTE) WHERE ("A"=1)) + # UNION (SELECT "A","B" FROM (CTE) WHERE ("A"=1)) + # UNION (SELECT "A","B" FROM (CTE) WHERE ("A"=1)) + assert ( + count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1])) + == 1 + ) + assert ( + Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause) + == 3 + ) + else: + # df1, df2, df3 all simplify to the same SQL and are merged into 1 CTE. 
+ # Generated SQL: + # WITH CTE AS (SELECT "A","B" FROM (VALUES ...) WHERE ("A"=1)) + # (CTE) UNION (CTE) UNION (CTE) + assert ( + count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1])) + == 1 + ) + assert ( + Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause) + == 1 + ) df5 = df1.join(df2).join(df3) check_result( @@ -988,18 +1097,20 @@ def test_sql_non_select(session): ) -def test_sql_with(session): +def test_sql_with(session, is_connect_mode): df1 = session.sql("with t as (select 1 as A) select * from t") df2 = session.sql("with t as (select 1 as A) select * from t") df_result = df1.union(df2).select("A").filter(lit(True)) + # df1 and df2 are different Python objects with the same SQL. + # In connect mode they should NOT be deduplicated. check_result( session, df_result, # with ... select is also treated as a select query # see is_sql_select_statement() function - expect_cte_optimized=True, + expect_cte_optimized=not is_connect_mode, query_count=1, describe_count=0, union_count=1, @@ -1325,7 +1436,7 @@ def test_table_select_cte(session): ], ) def test_dataframe_queries_with_cte_reuses_schema_cache( - session, reduce_describe_enabled, expected_describe_counts + session, reduce_describe_enabled, expected_describe_counts, is_connect_mode ): """Test that calling dataframe.queries (not same dataframe but same operation) multiple times with CTE optimization does not issue extra DESCRIBE queries when reduce_describe_query_enabled is True. @@ -1335,9 +1446,13 @@ def test_dataframe_queries_with_cte_reuses_schema_cache( identical SQL (with same CTE names), allowing the schema cache to hit. """ + # randomize column names to avoid schema cache hits from prior test runs in the same session. 
+ col_a = f"col_{uuid.uuid4().hex[:8]}" + col_b = f"col_{uuid.uuid4().hex[:8]}" + def create_cte_dataframe(): """Create a DataFrame that triggers CTE optimization (same df used twice).""" - df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df = session.create_dataframe([[1, 2], [3, 4]], schema=[col_a, col_b]) return df.union_all(df) def access_queries_and_schema(df): @@ -1347,7 +1462,7 @@ def access_queries_and_schema(df): with mock.patch.object( session, "_reduce_describe_query_enabled", reduce_describe_enabled - ), mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + ): for expected_describe_count in expected_describe_counts: df_union = create_cte_dataframe() with SqlCounter(query_count=0, describe_count=expected_describe_count): diff --git a/tests/integ/test_cte_connect_mode_dedup.py b/tests/integ/test_cte_connect_mode_dedup.py new file mode 100644 index 0000000000..ab23f710ee --- /dev/null +++ b/tests/integ/test_cte_connect_mode_dedup.py @@ -0,0 +1,441 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""CTE deduplication tests for _is_snowpark_connect_compatible_mode. + +When connect-compatible mode is enabled, two different DataFrame objects that +produce the same SQL should NOT be merged into a single CTE. Only the same +Python object referenced multiple times (e.g. df.union_all(df)) should be +deduplicated. + +This prevents incorrect results when non-deterministic functions like +uuid_string() are used: df1.union_all(df2) should produce two independent +evaluations, not a single CTE referenced twice. + +Tests cover: +1. Union with two distinct DFs (no CTE in connect mode) +2. Union with same DF ref (CTE still applies in connect mode) +3. Join-based triggers with distinct DFs +4. 
Chained operations producing imbalanced subtrees +""" + +import copy +from unittest import mock + +import pytest + +import snowflake.snowpark.context as context +from snowflake.snowpark.functions import col, lit, random, uuid_string +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) + +pytestmark = [ + pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="CTE is a SQL feature", + run=False, + ), +] + + +@pytest.fixture(scope="module") +def test_table_name(session): + """Create a shared test table with columns a (INT) and b (INT).""" + name = random_name_for_temp_object(TempObjectType.TABLE) + session.sql( + f""" + CREATE OR REPLACE TEMP TABLE {name} (a INT, b INT) + """ + ).collect() + session.sql( + f""" + INSERT INTO {name} VALUES (1, 2), (3, 4), (5, 6), (7, 8), (9, 10) + """ + ).collect() + yield name + session.sql(f"DROP TABLE IF EXISTS {name}").collect() + + +@pytest.fixture(autouse=True) +def enable_connect_compatible_mode(): + """Patch _is_snowpark_connect_compatible_mode to True for all tests in this module.""" + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + yield + + +def _get_query_cte_off_and_on(session, df): + """Get the last generated query with CTE off and CTE on using mock.patch.object.""" + with mock.patch.object(session, "_cte_optimization_enabled", False): + query_off = df.queries["queries"][-1] + with mock.patch.object(session, "_cte_optimization_enabled", True): + query_on = df.queries["queries"][-1] + return query_off, query_on + + +def _collect_uuid_halves(df): + """Collect a union DataFrame and return (top_half_uuids, bottom_half_uuids).""" + result = df.collect() + half = len(result) // 2 + assert half > 0, "Need at least 2 rows to compare halves" + top = [row["U"] for row in result[:half]] + bot = [row["U"] for row in result[half:]] + return top, bot + + +def assert_cte_sql_shape( + query_off: str, query_on: str, 
expect_cte: bool = True +) -> None: + """Assert that generated SQL has the expected CTE shape (no result comparison). + + Args: + query_off: The last generated query with CTE optimization disabled. + query_on: The last generated query with CTE optimization enabled. + expect_cte: If True, assert query_on starts with WITH. If False, assert it does not. + """ + assert ( + not query_off.strip().upper().startswith("WITH") + ), f"CTE OFF should not produce CTE SQL, got: {query_off[:120]}" + if expect_cte: + assert ( + query_on.strip().upper().startswith("WITH") + ), f"CTE ON should produce CTE SQL, got: {query_on[:120]}" + else: + assert ( + not query_on.strip().upper().startswith("WITH") + ), f"Expected no CTE, but got CTE SQL: {query_on[:120]}" + + +# --------------------------------------------------------------------------- +# Union: two distinct DFs vs same DF ref +# --------------------------------------------------------------------------- + + +def test_two_dfs_same_sql_no_cte_in_connect_mode(session, test_table_name): + """Two independently constructed DataFrames with uuid_string() should NOT + be merged into a CTE in connect-compatible mode. + Values in each half should differ (independent evaluations).""" + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df_union = df1.union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + top, bot = _collect_uuid_halves(df_union) + assert ( + top != bot + ), "Two distinct DFs should produce different uuid values in each half" + + +def test_same_df_ref_still_uses_cte_in_connect_mode(session, test_table_name): + """A single DataFrame referenced twice (df.union_all(df)) should still be + deduplicated in connect-compatible mode. 
+ Values in each half should be identical (same CTE evaluation reused).""" + base = session.table(test_table_name).select("a", "b") + df = base.select("a", uuid_string().alias("u")) + df_union = df.union_all(df) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + top, bot = _collect_uuid_halves(df_union) + assert ( + top == bot + ), "Same DF ref should produce identical uuid values in each half (CTE reuse)" + + +def test_connect_mode_with_random(session, test_table_name): + """random() with two separate DataFrames should not be CTE-merged in connect mode. + The random values in each half should differ (independent evaluations).""" + base = session.table(test_table_name).select("a") + df1 = base.select("a", random().alias("r")) + df2 = base.select("a", random().alias("r")) + df_union = df1.union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + half = len(result) // 2 + top_vals = [row["R"] for row in result[:half]] + bot_vals = [row["R"] for row in result[half:]] + assert ( + top_vals != bot_vals + ), "Two distinct DFs with random() should produce different values" + + +# --------------------------------------------------------------------------- +# Join: two distinct DFs with same SQL joined together +# --------------------------------------------------------------------------- + + +def test_join_two_distinct_dfs_no_cte_in_connect_mode(session, test_table_name): + """Two independently constructed DataFrames joined together should not + be CTE-merged when they are different objects in connect mode. 
+ The uuid columns from each side should contain different values.""" + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df_joined = df1.join(df2, df1["a"] == df2["a"]) + + query_off, query_on = _get_query_cte_off_and_on(session, df_joined) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_joined.collect() + assert len(result) > 0 + lhs_uuids = {row[1] for row in result} + rhs_uuids = {row[3] for row in result} + assert ( + lhs_uuids != rhs_uuids + ), "Joined distinct DFs should have different uuid columns" + + +def test_join_same_df_ref_uses_cte_in_connect_mode(session, test_table_name): + """A single DataFrame self-joined (via copy.copy) should still trigger CTE + in connect-compatible mode because the underlying from_ is the same object.""" + base = session.table(test_table_name).select( + col("a").alias("a"), col("b").alias("b") + ) + df = base.select(col("a").alias("a"), lit(1).alias("v")) + df_joined = df.natural_join(copy.copy(df)) + + query_off, query_on = _get_query_cte_off_and_on(session, df_joined) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + +# --------------------------------------------------------------------------- +# Chained operations producing imbalanced subtrees +# --------------------------------------------------------------------------- + + +def test_chained_filter_union_imbalanced_no_cte_connect_mode(session, test_table_name): + """Chained operations that produce structurally different trees but + identical SQL at the leaf level. The two branches have different depths. 
+ + Tree: union_all + / \\ + filter select + | | + select(uuid) select(uuid) ← different objects, same SQL + | | + base base ← same object (shared); skipped + by is_simple_select_entity + + In connect mode, the two select(uuid) nodes are different objects and + should not be CTE-merged, even though they have the same encoded id. + """ + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df1_filtered = df1.filter(col("a") > 1) + df_union = df1_filtered.union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + uuids = {row["U"] for row in result} + assert len(uuids) == len( + result + ), "All uuids should be unique across distinct DF branches" + + +def test_chained_agg_union_imbalanced_no_cte_connect_mode(session, test_table_name): + """Imbalanced tree where both branches aggregate independently. + + Tree: union_all + / \\ + group_by group_by + | | + select(uuid) select(uuid) ← different objects, same SQL + | | + base base ← same object (shared); skipped + by is_simple_select_entity + + The two select(uuid) nodes produce the same SQL but are different objects. + """ + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df1_agg = df1.group_by("a").count() + df2_agg = df2.group_by("a").count() + df_union = df1_agg.union_all(df2_agg) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + assert len(result) > 0, "Aggregated union should produce rows" + + +def test_chained_join_then_union_imbalanced_connect_mode(session, test_table_name): + """Three distinct DFs combined: two joined, then unioned with a third. 
+ + Tree: union_all + / \\ + join select(uuid) ← df3: independent + / \\ + select(uuid) select(uuid) ← df1, df2: independent + | | + base base ← same object (shared); skipped + by is_simple_select_entity + + All three select(uuid) nodes have the same SQL but are different objects. + None should be CTE-merged in connect mode. + """ + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df3 = base.select("a", uuid_string().alias("u")) + df_joined = df1.join(df2, df1["a"] == df2["a"]).select( + df1["a"].alias("a"), df1["u"].alias("u") + ) + df_union = df_joined.union_all(df3) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + assert len(result) > 0, "Join-then-union should produce rows" + uuids = [row["U"] for row in result] + assert len(set(uuids)) == len( + uuids + ), "All uuids should be unique across the three distinct DFs" + + +def test_chained_operations_same_ref_shared_subtree_cte_connect_mode( + session, test_table_name +): + """Chained operations where the same object is used in both branches. + CTE should still apply in connect mode because it's the same Python object. 
+ + Tree: union_all + / \\ + filter filter + | | + df df ← same object referenced twice + + """ + base = session.table(test_table_name).select( + col("a").alias("a"), col("b").alias("b") + ) + df = base.select(col("a").alias("a"), uuid_string().alias("u")) + left = df.filter(col("a") > 1) + right = df.filter(col("a") <= 9) + df_union = left.union_all(right) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + with mock.patch.object(session, "_cte_optimization_enabled", True): + result = df_union.collect() + left_uuids = {row["U"] for row in result if row["A"] > 1} + right_uuids = {row["U"] for row in result if row["A"] <= 9} + shared = left_uuids & right_uuids + assert len(shared) > 0, "Same DF ref with CTE should reuse uuids across branches" + + +def test_imbalanced_tree_non_simple_base_cte_connect_mode(session, test_table_name): + """Imbalanced tree where the shared base is NOT a simple select entity + (it has a filter), so it is eligible for CTE dedup. The two leaf branches + are different objects (different select(uuid) calls) so they should NOT be + CTE-merged, but the shared filtered base should be. + + Tree: union_all + / \\ + filter select(uuid) ← df2: distinct object, no CTE + | | + select(uuid) base ← same object in both branches; + | has filter → not simple select + base → eligible for CTE + + base is the same Python object shared by df1 and df2. Since base + is not a simple select entity (it has a filter), it qualifies for + CTE dedup even in connect mode. 
+ """ + base = ( + session.table(test_table_name) + .select(col("a").alias("a"), col("b").alias("b")) + .filter(col("a") > 0) + ) + df1 = base.select(col("a").alias("a"), uuid_string().alias("u")) + df2 = base.select(col("a").alias("a"), uuid_string().alias("u")) + df_union = df1.filter(col("a") > 3).union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + result = df_union.collect() + uuids = {row["U"] for row in result} + assert len(uuids) == len( + result + ), "All uuids should be unique — the two select(uuid) are distinct objects" + + +def test_imbalanced_join_shared_base_cte_connect_mode(session, test_table_name): + """Imbalanced join tree where the shared base (same object, not simple + select) appears in both branches. CTE should be applied for the shared + base even though the outer branches are different objects. + + Tree: join + / \\ + filter select(uuid) ← df2: distinct object, no CTE + | | + select(uuid) base ← same object (has filter, + | not simple select → CTE) + base + + base is the same Python object in both branches, with a filter that + prevents is_simple_select_entity from excluding it. The two + select(uuid) nodes are different objects so they won't be merged. 
+ """ + base = ( + session.table(test_table_name) + .select(col("a").alias("a"), col("b").alias("b")) + .filter(col("a") > 0) + ) + df1 = base.select(col("a").alias("a_l"), uuid_string().alias("u_l")) + df2 = base.select(col("a").alias("a_r"), uuid_string().alias("u_r")) + df_joined = df1.filter(col("a_l") > 3).join(df2, df1["a_l"] == df2["a_r"]) + + query_off, query_on = _get_query_cte_off_and_on(session, df_joined) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + result = df_joined.collect() + assert len(result) > 0 + lhs_uuids = {row["U_L"] for row in result} + rhs_uuids = {row["U_R"] for row in result} + assert ( + lhs_uuids != rhs_uuids + ), "Joined distinct DFs should have different uuid columns even with shared CTE base" + + +def test_distinct_objects_each_duplicated_still_cte(session, test_table_name): + """When multiple distinct objects share the same SQL AND each object is + itself referenced more than once, each should still be CTE-deduplicated. + + Tree: union_all (outer) + / \\ + union_all (left) union_all (right) + / \\ / \\ + df1 df1 df2 df2 + + df1 and df2 are different objects with the same SQL. + df1 appears twice → should be CTE'd. + df2 appears twice → should be CTE'd. 
+ """ + base = session.table(test_table_name).select( + col("a").alias("a"), col("b").alias("b") + ) + df1 = base.select(col("a").alias("a"), lit(1).alias("v")) + df2 = base.select(col("a").alias("a"), lit(1).alias("v")) + left = df1.union_all(df1) + right = df2.union_all(df2) + df_outer = left.union_all(right) + + query_off, query_on = _get_query_cte_off_and_on(session, df_outer) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 2014350f35..4e85fbb77f 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -86,6 +86,179 @@ def test_find_duplicate_subtrees(test_case): assert repeated_node_complexity == expected_repeated_node_complexity +def _create_mock_node(encoded_id, complexity=3): + """Helper to create a mock SnowflakePlan node for CTE tests.""" + node = mock.create_autospec(SnowflakePlan) + node.encoded_node_id_with_query = encoded_id + node.source_plan = None + node.cumulative_node_complexity = {PlanNodeCategory.COLUMN: complexity} + node.children_plan_nodes = [] + return node + + +def test_connect_mode_same_object_still_deduplicated(): + """When the same Python object is referenced multiple times (e.g. df.union_all(df)), + it should still be detected as a duplicate even in connect-compatible mode.""" + root = _create_mock_node("root_R") + shared_child = _create_mock_node("child_C") + leaf = _create_mock_node("leaf_L") + root.children_plan_nodes = [shared_child, shared_child] + shared_child.children_plan_nodes = [leaf] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert "child_C" in duplicated_ids + + +def test_connect_mode_different_objects_same_id_not_deduplicated(): + """When two different Python objects have the same encoded_node_id_with_query + (e.g. 
df1.union_all(df2) where df1 and df2 produce the same SQL),
+    find_duplicate_subtrees flags the encoded ID (raw count > 1), but
+    the per-object filtering in _replace_duplicate_node_with_cte will
+    skip them since each object appears only once."""
+    root = _create_mock_node("root_R")
+    child_a = _create_mock_node("same_S")
+    child_b = _create_mock_node("same_S")
+    leaf_a = _create_mock_node("leaf_L")
+    leaf_b = _create_mock_node("leaf_L")
+    root.children_plan_nodes = [child_a, child_b]
+    child_a.children_plan_nodes = [leaf_a]
+    child_b.children_plan_nodes = [leaf_b]
+
+    with mock.patch(
+        "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
+    ):
+        duplicated_ids, _ = find_duplicate_subtrees(root)
+        # The encoded IDs are flagged by raw count; per-object filtering
+        # happens downstream in _replace_duplicate_node_with_cte.
+        assert "same_S" in duplicated_ids
+
+
+def test_connect_mode_mixed_shared_and_distinct_objects():
+    """A tree with both shared objects (same ref) and distinct objects (different refs,
+    same encoded id).
+
+    Tree:        root
+               /      \\
+            left      right   (different objects, same encoded id "branch_B")
+             |          |
+           shared    shared   (same object, appears twice → should be deduplicated)
+
+    find_duplicate_subtrees flags branch_B by raw count (left + right).
+    shared_S is never flagged as a duplicate root, because its only parent
+    (branch_B) is itself duplicated, so deduplicating the parent subtree
+    already covers it; the per-object filtering in
+    _replace_duplicate_node_with_cte then skips left/right as well, since
+    each of those objects appears only once.
+    """
+    root = _create_mock_node("root_R")
+    left = _create_mock_node("branch_B")
+    right = _create_mock_node("branch_B")
+    shared = _create_mock_node("shared_S")
+    root.children_plan_nodes = [left, right]
+    left.children_plan_nodes = [shared]
+    right.children_plan_nodes = [shared]
+
+    with mock.patch(
+        "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
+    ):
+        duplicated_ids, _ = find_duplicate_subtrees(root)
+        # branch_B has raw count 2 (left + right) → flagged as duplicate.
+ # shared_S has raw count 2 but its only parent (branch_B) is also + # duplicated, so it's not the root of a duplicate subtree. + assert "branch_B" in duplicated_ids + assert "shared_S" not in duplicated_ids + + +def test_connect_mode_distinct_objects_each_duplicated(): + """When multiple distinct objects share the same encoded id AND each object + itself appears more than once, each should still be CTE-deduplicated. + + Tree: root + / \\ + union1 union2 + / \\ / \\ + df1 df1 df2 df2 ← df1 and df2 are different objects, same encoded id + df1 appears twice, df2 appears twice + + Both df1 and df2 should be deduplicated individually. + """ + root = _create_mock_node("root_R") + union1 = _create_mock_node("union1_U") + union2 = _create_mock_node("union2_U2") + df1 = _create_mock_node("same_S") + df2 = _create_mock_node("same_S") + root.children_plan_nodes = [union1, union2] + union1.children_plan_nodes = [df1, df1] + union2.children_plan_nodes = [df2, df2] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert "same_S" in duplicated_ids + + +def test_existing_cases_unchanged_in_connect_mode(): + """Existing test cases use the same object referenced multiple times, + so results should be the same even in connect-compatible mode.""" + for create_fn in [create_test_case1, create_test_case2]: + plan, expected_ids, expected_complexity = create_fn() + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + dup_ids, _ = find_duplicate_subtrees(plan) + assert dup_ids == expected_ids + + +def test_connect_mode_with_propagate_complexity_hist(): + """Verify that propagate_complexity_hist still works correctly in connect mode.""" + root = _create_mock_node("root_R") + shared = _create_mock_node("shared_S", complexity=50000) + root.children_plan_nodes = [shared, shared] + + with mock.patch( + 
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + dup_ids, complexity_hist = find_duplicate_subtrees( + root, propagate_complexity_hist=True + ) + assert "shared_S" in dup_ids + assert complexity_hist is not None + assert complexity_hist[1] == 2 # 50000 falls in bin 1 (> 10,000, <= 100,000) + + +def test_connect_mode_mixed_duplicated_and_unique_objects(): + """When multiple distinct objects share the same encoded id but only some + appear more than once, the encoded ID should still be flagged as + duplicated (because at least one object is genuinely duplicated). + + Tree: root + / \\ + union1 union2 + / \\ / \\ + df1 df1 df2 df3 ← df1 appears 2x (duplicated), df2 and df3 appear 1x each + + The encoded ID "same_S" should be in duplicated_node_ids because df1 + appears twice. The per-object filtering (df2/df3 not replaced) is + handled downstream in _replace_duplicate_node_with_cte. + """ + root = _create_mock_node("root_R") + union1 = _create_mock_node("union1_U") + union2 = _create_mock_node("union2_U2") + df1 = _create_mock_node("same_S") + df2 = _create_mock_node("same_S") + df3 = _create_mock_node("same_S") + root.children_plan_nodes = [union1, union2] + union1.children_plan_nodes = [df1, df1] + union2.children_plan_nodes = [df2, df3] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert "same_S" in duplicated_ids + + def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer): sql_text = "select 1 as a, 2 as b" select_sql_node = SelectSQL(