From 9df0b2198866d22ccbf1b98bd72e1d23546ca4d7 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 27 Mar 2026 14:14:08 -0700 Subject: [PATCH 1/5] implement dup node detection --- .../snowpark/_internal/compiler/cte_utils.py | 43 +- tests/integ/test_cte_connect_mode_dedup.py | 414 ++++++++++++++++++ tests/unit/test_cte.py | 101 +++++ 3 files changed, 551 insertions(+), 7 deletions(-) create mode 100644 tests/integ/test_cte_connect_mode_dedup.py diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index f9b46f4276..8deb2e6cd1 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -16,6 +16,7 @@ WithQueryBlock, ) from snowflake.snowpark._internal.utils import is_sql_select_statement +import snowflake.snowpark.context as context if TYPE_CHECKING: from snowflake.snowpark._internal.compiler.utils import TreeNode # pragma: no cover @@ -57,6 +58,14 @@ def find_duplicate_subtrees( # during this process invalid_ids_for_deduplication = set() + # When _is_snowpark_connect_compatible_mode is enabled, we track unique + # object identities per encoded_node_id to avoid merging nodes from + # different DataFrame construction calls that happen to produce + # identical SQL. Only the same Python object appearing multiple times + # (e.g. df.union_all(df)) should be treated as a duplicate. + use_object_identity = context._is_snowpark_connect_compatible_mode + object_ids_per_node_id: Dict[str, Set[int]] = defaultdict(set) + from snowflake.snowpark._internal.analyzer.select_statement import ( Selectable, SelectStatement, @@ -115,15 +124,17 @@ def traverse(root: "TreeNode") -> None: while len(current_level) > 0: next_level = [] for node in current_level: - id_node_map[node.encoded_node_id_with_query].append(node) + encoded_id = node.encoded_node_id_with_query + id_node_map[encoded_id].append(node) + + if use_object_identity: + object_ids_per_node_id[encoded_id].add(id(node)) if is_select_from_file_node(node): - invalid_ids_for_deduplication.add(node.encoded_node_id_with_query) + invalid_ids_for_deduplication.add(encoded_id) for child in node.children_plan_nodes: - id_parents_map[child.encoded_node_id_with_query].add( - node.encoded_node_id_with_query - ) + id_parents_map[child.encoded_node_id_with_query].add(encoded_id) next_level.append(child) current_level = next_level @@ -138,6 +149,24 @@ def traverse(root: "TreeNode") -> None: next_level.append(parent_id) current_level = next_level + def _node_occurrence_count(encoded_node_id_with_query: str) -> int: + """How many times this node appears in the tree. + + In connect-compatible mode, different Python objects with the same + encoded id are counted as distinct nodes (occurrence = 1 each). + A true duplicate requires the *same* object referenced from + multiple parents. + """ + total = len(id_node_map[encoded_node_id_with_query]) + if use_object_identity: + unique_objects = len(object_ids_per_node_id[encoded_node_id_with_query]) + if unique_objects > 1: + # Multiple distinct objects share the same encoded id. + # None of them should be considered duplicates of each other. + return 1 + # All entries are the same object (same id()) → keep original count. + return total + def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: # when a sql query is a select statement, its encoded_node_id_with_query # contains _, which is used to separate the query id and node type name. @@ -154,10 +183,10 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: if encoded_node_id_with_query in invalid_ids_for_deduplication: return False - is_duplicate_node = len(id_node_map[encoded_node_id_with_query]) > 1 + is_duplicate_node = _node_occurrence_count(encoded_node_id_with_query) > 1 if is_duplicate_node: is_any_parent_unique_node = any( - len(id_node_map[id]) == 1 + _node_occurrence_count(id) == 1 for id in id_parents_map[encoded_node_id_with_query] ) if is_any_parent_unique_node: diff --git a/tests/integ/test_cte_connect_mode_dedup.py b/tests/integ/test_cte_connect_mode_dedup.py new file mode 100644 index 0000000000..b17df1f8f5 --- /dev/null +++ b/tests/integ/test_cte_connect_mode_dedup.py @@ -0,0 +1,414 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +"""CTE deduplication tests for _is_snowpark_connect_compatible_mode. + +When connect-compatible mode is enabled, two different DataFrame objects that +produce the same SQL should NOT be merged into a single CTE. Only the same +Python object referenced multiple times (e.g. df.union_all(df)) should be +deduplicated. + +This prevents incorrect results when non-deterministic functions like +uuid_string() are used: df1.union_all(df2) should produce two independent +evaluations, not a single CTE referenced twice. + +Tests cover: +1. Union with two distinct DFs (no CTE in connect mode) +2. Union with same DF ref (CTE still applies in connect mode) +3. Join-based triggers with distinct DFs +4. Chained operations producing imbalanced subtrees +""" + +import copy +from unittest import mock + +import pytest + +import snowflake.snowpark.context as context +from snowflake.snowpark.functions import col, lit, random, uuid_string +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) + +pytestmark = [ + pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="CTE is a SQL feature", + run=False, + ), +] + + +@pytest.fixture(scope="module") +def test_table_name(session): + """Create a shared test table with columns a (INT) and b (INT).""" + name = random_name_for_temp_object(TempObjectType.TABLE) + session.sql( + f""" + CREATE OR REPLACE TEMP TABLE {name} (a INT, b INT) + """ + ).collect() + session.sql( + f""" + INSERT INTO {name} VALUES (1, 2), (3, 4), (5, 6), (7, 8), (9, 10) + """ + ).collect() + yield name + session.sql(f"DROP TABLE IF EXISTS {name}").collect() + + +@pytest.fixture(autouse=True) +def enable_connect_compatible_mode(): + """Patch _is_snowpark_connect_compatible_mode to True for all tests in this module.""" + with mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + yield + + +def _get_query_cte_off_and_on(session, df): + """Get the last generated query with CTE off and CTE on using mock.patch.object.""" + with mock.patch.object(session, "_cte_optimization_enabled", False): + query_off = df.queries["queries"][-1] + with mock.patch.object(session, "_cte_optimization_enabled", True): + query_on = df.queries["queries"][-1] + return query_off, query_on + + +def _collect_uuid_halves(df): + """Collect a union DataFrame and return (top_half_uuids, bottom_half_uuids).""" + result = df.collect() + half = len(result) // 2 + assert half > 0, "Need at least 2 rows to compare halves" + top = [row["U"] for row in result[:half]] + bot = [row["U"] for row in result[half:]] + return top, bot + + +def assert_cte_sql_shape( + query_off: str, query_on: str, expect_cte: bool = True +) -> None: + """Assert that generated SQL has the expected CTE shape (no result comparison). + + Args: + query_off: The last generated query with CTE optimization disabled. + query_on: The last generated query with CTE optimization enabled. + expect_cte: If True, assert query_on starts with WITH. If False, assert it does not. + """ + assert ( + not query_off.strip().upper().startswith("WITH") + ), f"CTE OFF should not produce CTE SQL, got: {query_off[:120]}" + if expect_cte: + assert ( + query_on.strip().upper().startswith("WITH") + ), f"CTE ON should produce CTE SQL, got: {query_on[:120]}" + else: + assert ( + not query_on.strip().upper().startswith("WITH") + ), f"Expected no CTE, but got CTE SQL: {query_on[:120]}" + + +# --------------------------------------------------------------------------- +# Union: two distinct DFs vs same DF ref +# --------------------------------------------------------------------------- + + +def test_two_dfs_same_sql_no_cte_in_connect_mode(session, test_table_name): + """Two independently constructed DataFrames with uuid_string() should NOT + be merged into a CTE in connect-compatible mode. + Values in each half should differ (independent evaluations).""" + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df_union = df1.union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + top, bot = _collect_uuid_halves(df_union) + assert ( + top != bot + ), "Two distinct DFs should produce different uuid values in each half" + + +def test_same_df_ref_still_uses_cte_in_connect_mode(session, test_table_name): + """A single DataFrame referenced twice (df.union_all(df)) should still be + deduplicated in connect-compatible mode. + Values in each half should be identical (same CTE evaluation reused).""" + base = session.table(test_table_name).select("a", "b") + df = base.select("a", uuid_string().alias("u")) + df_union = df.union_all(df) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + top, bot = _collect_uuid_halves(df_union) + assert ( + top == bot + ), "Same DF ref should produce identical uuid values in each half (CTE reuse)" + + +def test_connect_mode_with_random(session, test_table_name): + """random() with two separate DataFrames should not be CTE-merged in connect mode. + The random values in each half should differ (independent evaluations).""" + base = session.table(test_table_name).select("a") + df1 = base.select("a", random().alias("r")) + df2 = base.select("a", random().alias("r")) + df_union = df1.union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + half = len(result) // 2 + top_vals = [row["R"] for row in result[:half]] + bot_vals = [row["R"] for row in result[half:]] + assert ( + top_vals != bot_vals + ), "Two distinct DFs with random() should produce different values" + + +# --------------------------------------------------------------------------- +# Join: two distinct DFs with same SQL joined together +# --------------------------------------------------------------------------- + + +def test_join_two_distinct_dfs_no_cte_in_connect_mode(session, test_table_name): + """Two independently constructed DataFrames joined together should not + be CTE-merged when they are different objects in connect mode. + The uuid columns from each side should contain different values.""" + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df_joined = df1.join(df2, df1["a"] == df2["a"]) + + query_off, query_on = _get_query_cte_off_and_on(session, df_joined) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_joined.collect() + assert len(result) > 0 + lhs_uuids = {row[1] for row in result} + rhs_uuids = {row[3] for row in result} + assert ( + lhs_uuids != rhs_uuids + ), "Joined distinct DFs should have different uuid columns" + + +def test_join_same_df_ref_uses_cte_in_connect_mode(session, test_table_name): + """A single DataFrame self-joined (via copy.copy) should still trigger CTE + in connect-compatible mode because the underlying from_ is the same object.""" + base = session.table(test_table_name).select( + col("a").alias("a"), col("b").alias("b") + ) + df = base.select(col("a").alias("a"), lit(1).alias("v")) + df_joined = df.natural_join(copy.copy(df)) + + query_off, query_on = _get_query_cte_off_and_on(session, df_joined) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + +# --------------------------------------------------------------------------- +# Chained operations producing imbalanced subtrees +# --------------------------------------------------------------------------- + + +def test_chained_filter_union_imbalanced_no_cte_connect_mode(session, test_table_name): + """Chained operations that produce structurally different trees but + identical SQL at the leaf level. The two branches have different depths. + + Tree: union_all + / \\ + filter select + | | + select(uuid) select(uuid) ← different objects, same SQL + | | + base base ← same object (shared); skipped + by is_simple_select_entity + + In connect mode, the two select(uuid) nodes are different objects and + should not be CTE-merged, even though they have the same encoded id. + """ + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df1_filtered = df1.filter(col("a") > 1) + df_union = df1_filtered.union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + uuids = {row["U"] for row in result} + assert len(uuids) == len( + result + ), "All uuids should be unique across distinct DF branches" + + +def test_chained_agg_union_imbalanced_no_cte_connect_mode(session, test_table_name): + """Imbalanced tree where both branches aggregate independently. + + Tree: union_all + / \\ + group_by group_by + | | + select(uuid) select(uuid) ← different objects, same SQL + | | + base base ← same object (shared); skipped + by is_simple_select_entity + + The two select(uuid) nodes produce the same SQL but are different objects. + """ + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df1_agg = df1.group_by("a").count() + df2_agg = df2.group_by("a").count() + df_union = df1_agg.union_all(df2_agg) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + assert len(result) > 0, "Aggregated union should produce rows" + + +def test_chained_join_then_union_imbalanced_connect_mode(session, test_table_name): + """Three distinct DFs combined: two joined, then unioned with a third. + + Tree: union_all + / \\ + join select(uuid) ← df3: independent + / \\ + select(uuid) select(uuid) ← df1, df2: independent + | | + base base ← same object (shared); skipped + by is_simple_select_entity + + All three select(uuid) nodes have the same SQL but are different objects. + None should be CTE-merged in connect mode. + """ + base = session.table(test_table_name).select("a", "b") + df1 = base.select("a", uuid_string().alias("u")) + df2 = base.select("a", uuid_string().alias("u")) + df3 = base.select("a", uuid_string().alias("u")) + df_joined = df1.join(df2, df1["a"] == df2["a"]).select( + df1["a"].alias("a"), df1["u"].alias("u") + ) + df_union = df_joined.union_all(df3) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=False) + + result = df_union.collect() + assert len(result) > 0, "Join-then-union should produce rows" + uuids = [row["U"] for row in result] + assert len(set(uuids)) == len( + uuids + ), "All uuids should be unique across the three distinct DFs" + + +def test_chained_operations_same_ref_shared_subtree_cte_connect_mode( + session, test_table_name +): + """Chained operations where the same object is used in both branches. + CTE should still apply in connect mode because it's the same Python object. + + Tree: union_all + / \\ + filter filter + | | + df df ← same object referenced twice + + """ + base = session.table(test_table_name).select( + col("a").alias("a"), col("b").alias("b") + ) + df = base.select(col("a").alias("a"), uuid_string().alias("u")) + left = df.filter(col("a") > 1) + right = df.filter(col("a") <= 9) + df_union = left.union_all(right) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + with mock.patch.object(session, "_cte_optimization_enabled", True): + result = df_union.collect() + left_uuids = {row["U"] for row in result if row["A"] > 1} + right_uuids = {row["U"] for row in result if row["A"] <= 9} + shared = left_uuids & right_uuids + assert len(shared) > 0, "Same DF ref with CTE should reuse uuids across branches" + + +def test_imbalanced_tree_non_simple_base_cte_connect_mode(session, test_table_name): + """Imbalanced tree where the shared base is NOT a simple select entity + (it has a filter), so it is eligible for CTE dedup. The two leaf branches + are different objects (different select(uuid) calls) so they should NOT be + CTE-merged, but the shared filtered base should be. + + Tree: union_all + / \\ + filter select(uuid) ← df2: distinct object, no CTE + | | + select(uuid) base ← same object in both branches; + | has filter → not simple select + base → eligible for CTE + + base is the same Python object shared by df1 and df2. Since base + is not a simple select entity (it has a filter), it qualifies for + CTE dedup even in connect mode. + """ + base = ( + session.table(test_table_name) + .select(col("a").alias("a"), col("b").alias("b")) + .filter(col("a") > 0) + ) + df1 = base.select(col("a").alias("a"), uuid_string().alias("u")) + df2 = base.select(col("a").alias("a"), uuid_string().alias("u")) + df_union = df1.filter(col("a") > 3).union_all(df2) + + query_off, query_on = _get_query_cte_off_and_on(session, df_union) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + result = df_union.collect() + uuids = {row["U"] for row in result} + assert len(uuids) == len( + result + ), "All uuids should be unique — the two select(uuid) are distinct objects" + + +def test_imbalanced_join_shared_base_cte_connect_mode(session, test_table_name): + """Imbalanced join tree where the shared base (same object, not simple + select) appears in both branches. CTE should be applied for the shared + base even though the outer branches are different objects. + + Tree: join + / \\ + filter select(uuid) ← df2: distinct object, no CTE + | | + select(uuid) base ← same object (has filter, + | not simple select → CTE) + base + + base is the same Python object in both branches, with a filter that + prevents is_simple_select_entity from excluding it. The two + select(uuid) nodes are different objects so they won't be merged. + """ + base = ( + session.table(test_table_name) + .select(col("a").alias("a"), col("b").alias("b")) + .filter(col("a") > 0) + ) + df1 = base.select(col("a").alias("a_l"), uuid_string().alias("u_l")) + df2 = base.select(col("a").alias("a_r"), uuid_string().alias("u_r")) + df_joined = df1.filter(col("a_l") > 3).join(df2, df1["a_l"] == df2["a_r"]) + + query_off, query_on = _get_query_cte_off_and_on(session, df_joined) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) + + result = df_joined.collect() + assert len(result) > 0 + lhs_uuids = {row["U_L"] for row in result} + rhs_uuids = {row["U_R"] for row in result} + assert ( + lhs_uuids != rhs_uuids + ), "Joined distinct DFs should have different uuid columns even with shared CTE base" diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 2014350f35..3aeec6c73d 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -86,6 +86,107 @@ def test_find_duplicate_subtrees(test_case): assert repeated_node_complexity == expected_repeated_node_complexity +def _create_mock_node(encoded_id, complexity=3): + """Helper to create a mock SnowflakePlan node for CTE tests.""" + node = mock.create_autospec(SnowflakePlan) + node.encoded_node_id_with_query = encoded_id + node.source_plan = None + node.cumulative_node_complexity = {PlanNodeCategory.COLUMN: complexity} + node.children_plan_nodes = [] + return node + + +def test_connect_mode_same_object_still_deduplicated(): + """When the same Python object is referenced multiple times (e.g. df.union_all(df)), + it should still be detected as a duplicate even in connect-compatible mode.""" + root = _create_mock_node("root_R") + shared_child = _create_mock_node("child_C") + leaf = _create_mock_node("leaf_L") + root.children_plan_nodes = [shared_child, shared_child] + shared_child.children_plan_nodes = [leaf] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert "child_C" in duplicated_ids + + +def test_connect_mode_different_objects_same_id_not_deduplicated(): + """When two different Python objects have the same encoded_node_id_with_query + (e.g. df1.union_all(df2) where df1 and df2 produce the same SQL), + they should NOT be treated as duplicates in connect-compatible mode.""" + root = _create_mock_node("root_R") + child_a = _create_mock_node("same_S") + child_b = _create_mock_node("same_S") + leaf_a = _create_mock_node("leaf_L") + leaf_b = _create_mock_node("leaf_L") + root.children_plan_nodes = [child_a, child_b] + child_a.children_plan_nodes = [leaf_a] + child_b.children_plan_nodes = [leaf_b] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert len(duplicated_ids) == 0 + + +def test_connect_mode_mixed_shared_and_distinct_objects(): + """A tree with both shared objects (same ref) and distinct objects (different refs, + same encoded id). Only the shared object should be deduplicated in connect mode. + + Tree: root + / \\ + left right (different objects, same encoded id "branch_B") + | | + shared shared (same object, appears twice → should be deduplicated) + """ + root = _create_mock_node("root_R") + left = _create_mock_node("branch_B") + right = _create_mock_node("branch_B") + shared = _create_mock_node("shared_S") + root.children_plan_nodes = [left, right] + left.children_plan_nodes = [shared] + right.children_plan_nodes = [shared] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert "shared_S" in duplicated_ids + assert "branch_B" not in duplicated_ids + + +def test_existing_cases_unchanged_in_connect_mode(): + """Existing test cases use the same object referenced multiple times, + so results should be the same even in connect-compatible mode.""" + for create_fn in [create_test_case1, create_test_case2]: + plan, expected_ids, expected_complexity = create_fn() + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + dup_ids, _ = find_duplicate_subtrees(plan) + assert dup_ids == expected_ids + + +def test_connect_mode_with_propagate_complexity_hist(): + """Verify that propagate_complexity_hist still works correctly in connect mode.""" + root = _create_mock_node("root_R") + shared = _create_mock_node("shared_S", complexity=50000) + root.children_plan_nodes = [shared, shared] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + dup_ids, complexity_hist = find_duplicate_subtrees( + root, propagate_complexity_hist=True + ) + assert "shared_S" in dup_ids + assert complexity_hist is not None + assert complexity_hist[1] == 2 # 50000 falls in bin 1 (> 10,000, <= 100,000) + + def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer): sql_text = "select 1 as a, 2 as b" select_sql_node = SelectSQL( From de9d72202c026b68ffe1b5fa7a547d6f58189c8c Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 27 Mar 2026 15:20:38 -0700 Subject: [PATCH 2/5] add tests --- .../snowpark/_internal/compiler/cte_utils.py | 13 ++- tests/integ/test_cte.py | 104 ++++++++++++++---- 2 files changed, 87 insertions(+), 30 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index 8deb2e6cd1..cea21cf1e8 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -159,12 +159,13 @@ def _node_occurrence_count(encoded_node_id_with_query: str) -> int: """ total = len(id_node_map[encoded_node_id_with_query]) if use_object_identity: - unique_objects = len(object_ids_per_node_id[encoded_node_id_with_query]) - if unique_objects > 1: - # Multiple distinct objects share the same encoded id. - # None of them should be considered duplicates of each other. - return 1 - # All entries are the same object (same id()) → keep original count. + # If there are multiple distinct objects with the same encoded id, return 1. + # Otherwise, return the total number of occurrences. + return ( + 1 + if len(object_ids_per_node_id[encoded_node_id_with_query]) > 1 + else total + ) return total def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 1adcbe3558..49a2a4ef0f 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -5,6 +5,7 @@ import re import tracemalloc from unittest import mock +import uuid import pytest @@ -60,8 +61,17 @@ WITH = "WITH" +@pytest.fixture(params=[False, True], ids=["connect_mode_off", "connect_mode_on"]) +def is_connect_mode(request): + """Parametrize every test over _is_snowpark_connect_compatible_mode.""" + with mock.patch.object( + context, "_is_snowpark_connect_compatible_mode", request.param + ): + yield request.param + + @pytest.fixture(autouse=True) -def setup(request, session): +def setup(request, session, is_connect_mode): is_cte_optimization_enabled = session._cte_optimization_enabled is_query_compilation_enabled = session._query_compilation_stage_enabled session._query_compilation_stage_enabled = True @@ -251,7 +261,9 @@ def test_binary(session, type, action): assert len(plan_queries["post_actions"]) == 1 -def test_join_with_alias_dataframe(session): +def test_join_with_alias_dataframe(session, is_connect_mode): + c1 = f"col1_{uuid.uuid4().hex[:8]}" + c2 = f"col2_{uuid.uuid4().hex[:8]}" expected_describe_count = ( 3 if (session.reduce_describe_query_enabled and session.sql_simplifier_enabled) @@ -260,11 +272,11 @@ def test_join_with_alias_dataframe(session): with SqlCounter( query_count=2, describe_count=expected_describe_count, join_count=2 ): - df1 = session.create_dataframe([[1, 6]], schema=["col1", "col2"]) + df1 = session.create_dataframe([[1, 6]], schema=[c1, c2]) df_res = ( df1.alias("L") - .join(df1.alias("R"), col("L", "col1") == col("R", "col1")) - .select(col("L", "col1"), col("R", "col2")) + .join(df1.alias("R"), col("L", c1) == col("R", c1)) + .select(col("L", c1), col("R", c2)) ) session._cte_optimization_enabled = False @@ -355,7 +367,7 @@ def test_join_with_set_operation(session): @pytest.mark.parametrize("type, action", binary_operations) -def test_variable_binding_binary(session, type, action): +def test_variable_binding_binary(session, type, action, is_connect_mode): df1 = session.sql( "select $1 as a, $2 as b from values (?, ?), (?, ?)", params=[1, "a", 2, "b"] ) @@ -372,10 +384,12 @@ def test_variable_binding_binary(session, type, action): join_count = 1 if type == "union": union_count = 1 + # df1 and df3 are different Python objects with the same SQL. + # In connect mode they should NOT be deduplicated. check_result( session, action(df1, df3), - expect_cte_optimized=True, + expect_cte_optimized=not is_connect_mode, query_count=1, describe_count=0, union_count=union_count, @@ -551,21 +565,24 @@ def test_number_of_ctes(session, type, action): ) -def test_different_df_same_query(session): +def test_different_df_same_query(session, is_connect_mode): df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") df = df2.union_all(df1) + # df1 and df2 are different Python objects with the same SQL. + # In connect mode they should NOT be deduplicated. check_result( session, df, - expect_cte_optimized=True, + expect_cte_optimized=not is_connect_mode, query_count=1, describe_count=0, union_count=1, join_count=0, ) with SqlCounter(query_count=0, describe_count=0): - assert count_number_of_ctes(df.queries["queries"][-1]) == 1 + expected_cte_count = 0 if is_connect_mode else 1 + assert count_number_of_ctes(df.queries["queries"][-1]) == expected_cte_count def test_same_duplicate_subtree(session): @@ -624,7 +641,7 @@ def test_same_duplicate_subtree(session): @pytest.mark.parametrize("use_different_df", [True, False]) -def test_cte_preserves_join_suffix_aliases(session, use_different_df): +def test_cte_preserves_join_suffix_aliases(session, use_different_df, is_connect_mode): df_ad_group = session.create_dataframe( [["1048771", "group_1", "campaign_1"]], schema=["ACCOUNT_ID", "AD_GROUP_ID", "CAMPAIGN_ID"], @@ -695,8 +712,14 @@ def test_cte_preserves_join_suffix_aliases(session, use_different_df): assert 'ON ("AD_GROUP_ID" = "AD_GROUP_ID")' not in union_sql # when using different df_ad_group with disambiguation, because rsuffix in join, # they have different alias map (expr_to_alias), so they are considered different and we can't convert them to a CTE - # However there is still a CTE for create_dataframe call - assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == 1 + # However there is still a CTE for create_dataframe call. + # In connect mode with use_different_df, all create_dataframe calls are + # distinct objects so no CTEs are produced. + if is_connect_mode and use_different_df: + expected_cte_count = 0 + else: + expected_cte_count = 1 + assert count_number_of_ctes(Utils.normalize_sql(union_sql)) == expected_cte_count @pytest.mark.parametrize( @@ -807,7 +830,7 @@ def test_explain(session): assert "WITH SNOWPARK_TEMP_CTE" in explain_string -def test_sql_simplifier(session): +def test_sql_simplifier(session, is_connect_mode): if not session._sql_simplifier_enabled: pytest.skip("SQL simplifier is not enabled") @@ -822,6 +845,9 @@ def test_sql_simplifier(session): df2 = df1.select("a", "b") df3 = df1.select("a", "b").select("a", "b") df4 = df1.union_by_name(df2).union_by_name(df3) + # df1, df2, df3 are different Python objects that simplify to the same SQL. + # In connect mode they are not deduplicated, but df (create_dataframe) is + # still the same object appearing across all branches → still CTE'd. check_result( session, df4, @@ -832,11 +858,35 @@ def test_sql_simplifier(session): join_count=0, ) with SqlCounter(query_count=0, describe_count=0): - # after applying sql simplifier, there is only one CTE (df1, df2, df3 have the same query) - assert ( - count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1])) == 1 - ) - assert Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause) == 1 + if is_connect_mode: + # df1, df2, df3 are different objects → not merged. + # Only df (create_dataframe) is the same object across all branches → 1 CTE. + # Generated SQL: + # WITH CTE AS (SELECT $1 AS "A", $2 AS "B" FROM VALUES ...) + # (SELECT "A","B" FROM (CTE) WHERE ("A"=1)) + # UNION (SELECT "A","B" FROM (CTE) WHERE ("A"=1)) + # UNION (SELECT "A","B" FROM (CTE) WHERE ("A"=1)) + assert ( + count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1])) + == 1 + ) + assert ( + Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause) + == 3 + ) + else: + # df1, df2, df3 all simplify to the same SQL and are merged into 1 CTE. + # Generated SQL: + # WITH CTE AS (SELECT "A","B" FROM (VALUES ...) WHERE ("A"=1)) + # (CTE) UNION (CTE) UNION (CTE) + assert ( + count_number_of_ctes(Utils.normalize_sql(df4.queries["queries"][-1])) + == 1 + ) + assert ( + Utils.normalize_sql(df4.queries["queries"][-1]).count(filter_clause) + == 1 + ) df5 = df1.join(df2).join(df3) check_result( @@ -988,18 +1038,20 @@ def test_sql_non_select(session): ) -def test_sql_with(session): +def test_sql_with(session, is_connect_mode): df1 = session.sql("with t as (select 1 as A) select * from t") df2 = session.sql("with t as (select 1 as A) select * from t") df_result = df1.union(df2).select("A").filter(lit(True)) + # df1 and df2 are different Python objects with the same SQL. + # In connect mode they should NOT be deduplicated. check_result( session, df_result, # with ... select is also treated as a select query # see is_sql_select_statement() function - expect_cte_optimized=True, + expect_cte_optimized=not is_connect_mode, query_count=1, describe_count=0, union_count=1, @@ -1325,7 +1377,7 @@ def test_table_select_cte(session): ], ) def test_dataframe_queries_with_cte_reuses_schema_cache( - session, reduce_describe_enabled, expected_describe_counts + session, reduce_describe_enabled, expected_describe_counts, is_connect_mode ): """Test that calling dataframe.queries (not same dataframe but same operation) multiple times with CTE optimization does not issue extra DESCRIBE queries when reduce_describe_query_enabled is True. @@ -1335,9 +1387,13 @@ def test_dataframe_queries_with_cte_reuses_schema_cache( identical SQL (with same CTE names), allowing the schema cache to hit. """ + # randomize column names to avoid schema cache hits from prior test runs in the same session. + col_a = f"col_{uuid.uuid4().hex[:8]}" + col_b = f"col_{uuid.uuid4().hex[:8]}" + def create_cte_dataframe(): """Create a DataFrame that triggers CTE optimization (same df used twice).""" - df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df = session.create_dataframe([[1, 2], [3, 4]], schema=[col_a, col_b]) return df.union_all(df) def access_queries_and_schema(df): @@ -1347,7 +1403,7 @@ def access_queries_and_schema(df): with mock.patch.object( session, "_reduce_describe_query_enabled", reduce_describe_enabled - ), mock.patch.object(context, "_is_snowpark_connect_compatible_mode", True): + ): for expected_describe_count in expected_describe_counts: df_union = create_cte_dataframe() with SqlCounter(query_count=0, describe_count=expected_describe_count): From c9dc06b3ecc6a0bdd44d737728204c4cb9a168cd Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Fri, 27 Mar 2026 17:02:54 -0700 Subject: [PATCH 3/5] bug fix --- .../snowpark/_internal/compiler/cte_utils.py | 21 +++++++++----- tests/integ/test_cte_connect_mode_dedup.py | 27 +++++++++++++++++ tests/unit/test_cte.py | 29 +++++++++++++++++++ 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index cea21cf1e8..a2d8af56e8 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -4,7 +4,7 @@ import hashlib import logging -from collections import defaultdict +from collections import Counter, defaultdict from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( @@ -156,16 +156,21 @@ def _node_occurrence_count(encoded_node_id_with_query: str) -> int: encoded id are counted as distinct nodes (occurrence = 1 each). A true duplicate requires the *same* object referenced from multiple parents. + + When multiple distinct objects share the same encoded id, we return + the max occurrence count among any single object. This handles cases + like union(union(df1, df1), union(df2, df2)) where df1 and df2 + produce identical SQL but df1 itself appears twice and should be + CTE-deduplicated. """ total = len(id_node_map[encoded_node_id_with_query]) if use_object_identity: - # If there are multiple distinct objects with the same encoded id, return 1. - # Otherwise, return the total number of occurrences. - return ( - 1 - if len(object_ids_per_node_id[encoded_node_id_with_query]) > 1 - else total - ) + object_ids = object_ids_per_node_id[encoded_node_id_with_query] + if len(object_ids) > 1: + id_counts = Counter( + id(node) for node in id_node_map[encoded_node_id_with_query] + ) + return max(id_counts.values()) return total def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: diff --git a/tests/integ/test_cte_connect_mode_dedup.py b/tests/integ/test_cte_connect_mode_dedup.py index b17df1f8f5..ab23f710ee 100644 --- a/tests/integ/test_cte_connect_mode_dedup.py +++ b/tests/integ/test_cte_connect_mode_dedup.py @@ -412,3 +412,30 @@ def test_imbalanced_join_shared_base_cte_connect_mode(session, test_table_name): assert ( lhs_uuids != rhs_uuids ), "Joined distinct DFs should have different uuid columns even with shared CTE base" + + +def test_distinct_objects_each_duplicated_still_cte(session, test_table_name): + """When multiple distinct objects share the same SQL AND each object is + itself referenced more than once, each should still be CTE-deduplicated. + + Tree: union_all (outer) + / \\ + union_all (left) union_all (right) + / \\ / \\ + df1 df1 df2 df2 + + df1 and df2 are different objects with the same SQL. + df1 appears twice → should be CTE'd. + df2 appears twice → should be CTE'd. + """ + base = session.table(test_table_name).select( + col("a").alias("a"), col("b").alias("b") + ) + df1 = base.select(col("a").alias("a"), lit(1).alias("v")) + df2 = base.select(col("a").alias("a"), lit(1).alias("v")) + left = df1.union_all(df1) + right = df2.union_all(df2) + df_outer = left.union_all(right) + + query_off, query_on = _get_query_cte_off_and_on(session, df_outer) + assert_cte_sql_shape(query_off, query_on, expect_cte=True) diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index 3aeec6c73d..b6db420f8c 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -158,6 +158,35 @@ def test_connect_mode_mixed_shared_and_distinct_objects(): assert "branch_B" not in duplicated_ids +def test_connect_mode_distinct_objects_each_duplicated(): + """When multiple distinct objects share the same encoded id AND each object + itself appears more than once, each should still be CTE-deduplicated. + + Tree: root + / \\ + union1 union2 + / \\ / \\ + df1 df1 df2 df2 ← df1 and df2 are different objects, same encoded id + df1 appears twice, df2 appears twice + + Both df1 and df2 should be deduplicated individually. + """ + root = _create_mock_node("root_R") + union1 = _create_mock_node("union1_U") + union2 = _create_mock_node("union2_U2") + df1 = _create_mock_node("same_S") + df2 = _create_mock_node("same_S") + root.children_plan_nodes = [union1, union2] + union1.children_plan_nodes = [df1, df1] + union2.children_plan_nodes = [df2, df2] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert "same_S" in duplicated_ids + + def test_existing_cases_unchanged_in_connect_mode(): """Existing test cases use the same object referenced multiple times, so results should be the same even in connect-compatible mode.""" From 52f6cf608b116fb50d3f522f783c1f1b802144a8 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Mon, 6 Apr 2026 14:49:03 -0700 Subject: [PATCH 4/5] update --- .../snowpark/_internal/compiler/cte_utils.py | 42 +++---------- .../compiler/repeated_subquery_elimination.py | 57 +++++++++++++----- tests/integ/test_cte.py | 59 +++++++++++++++++++ tests/unit/test_cte.py | 53 +++++++++++++++-- 4 files changed, 157 insertions(+), 54 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index a2d8af56e8..7165cccf8e 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -4,7 +4,7 @@ import hashlib import logging -from collections import Counter, defaultdict +from collections import defaultdict from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( @@ -16,7 +16,6 @@ WithQueryBlock, ) from snowflake.snowpark._internal.utils import is_sql_select_statement -import snowflake.snowpark.context as context if TYPE_CHECKING: from snowflake.snowpark._internal.compiler.utils import TreeNode # pragma: no cover @@ -58,14 +57,6 @@ def find_duplicate_subtrees( # during this process invalid_ids_for_deduplication = set() - # When _is_snowpark_connect_compatible_mode is enabled, we track unique - # object identities per encoded_node_id to avoid merging nodes from - # different DataFrame construction calls that happen to produce - # identical SQL. Only the same Python object appearing multiple times - # (e.g. df.union_all(df)) should be treated as a duplicate. - use_object_identity = context._is_snowpark_connect_compatible_mode - object_ids_per_node_id: Dict[str, Set[int]] = defaultdict(set) - from snowflake.snowpark._internal.analyzer.select_statement import ( Selectable, SelectStatement, @@ -127,9 +118,6 @@ def traverse(root: "TreeNode") -> None: encoded_id = node.encoded_node_id_with_query id_node_map[encoded_id].append(node) - if use_object_identity: - object_ids_per_node_id[encoded_id].add(id(node)) - if is_select_from_file_node(node): invalid_ids_for_deduplication.add(encoded_id) @@ -150,28 +138,14 @@ def traverse(root: "TreeNode") -> None: current_level = next_level def _node_occurrence_count(encoded_node_id_with_query: str) -> int: - """How many times this node appears in the tree. - - In connect-compatible mode, different Python objects with the same - encoded id are counted as distinct nodes (occurrence = 1 each). - A true duplicate requires the *same* object referenced from - multiple parents. - - When multiple distinct objects share the same encoded id, we return - the max occurrence count among any single object. This handles cases - like union(union(df1, df1), union(df2, df2)) where df1 and df2 - produce identical SQL but df1 itself appears twice and should be - CTE-deduplicated. + """How many times this encoded node ID appears in the tree. + + This is a raw count based on encoded ID only. In connect-compatible + mode this may over-count (treating different Python objects with the + same SQL as duplicates). The per-object filtering is handled + downstream in _replace_duplicate_node_with_cte. """ - total = len(id_node_map[encoded_node_id_with_query]) - if use_object_identity: - object_ids = object_ids_per_node_id[encoded_node_id_with_query] - if len(object_ids) > 1: - id_counts = Counter( - id(node) for node in id_node_map[encoded_node_id_with_query] - ) - return max(id_counts.values()) - return total + return len(id_node_map[encoded_node_id_with_query]) def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: # when a sql query is a select statement, its encoded_node_id_with_query diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index e8287140a7..6bb66e5b2b 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # -from collections import defaultdict +from collections import Counter, defaultdict from typing import Dict, List, Optional, Set from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan @@ -117,18 +117,30 @@ def _replace_duplicate_node_with_cte( Replace all duplicated nodes with a WithQueryBlock (CTE node), to enable query generation with CTEs. + In connect-compatible mode, per-object occurrence counts are built + during the traversal so that only nodes whose specific Python object + appears more than once are CTE-replaced. Different objects that + share the same ``encoded_node_id_with_query`` but each appear only + once are left as inline subqueries. + NOTE, we use stack to perform a post-order traversal instead of recursive call. The reason of using the stack approach is that chained CTEs have to be built from bottom (innermost subquery) to top (outermost query). This function uses an iterative approach to avoid hitting Python's maximum recursion depth limit. """ + use_per_object = context._is_snowpark_connect_compatible_mode node_parents_map: Dict[TreeNode, Set[TreeNode]] = defaultdict(set) + # Per-object occurrence counts, built during the traversal below. + # Only populated in connect-compatible mode. + object_id_counts: Counter = Counter() stack1, stack2 = [root], [] while stack1: node = stack1.pop() stack2.append(node) + if use_per_object: + object_id_counts[id(node)] += 1 for child in reversed(node.children_plan_nodes): node_parents_map[child].add(node) stack1.append(child) @@ -136,8 +148,17 @@ def _replace_duplicate_node_with_cte( # track node that is already visited to avoid repeated operation on the same node visited_nodes: Set[TreeNode] = set() updated_nodes: Set[TreeNode] = set() - # track the resolved WithQueryBlock node has been created for each duplicated node + # track the resolved WithQueryBlock node has been created for each duplicated node. + # In per-object mode the key is str(id(node)); otherwise it is encoded_node_id_with_query. resolved_with_block_map: Dict[str, SnowflakePlan] = {} + # Counter for deterministic CTE naming when multiple distinct objects + # share the same encoded ID (connect mode only). + per_object_cte_counter: Dict[str, int] = defaultdict(int) + + def _map_key_for(node: TreeNode) -> str: + if use_per_object: + return str(id(node)) + return node.encoded_node_id_with_query def _update_parents( node: TreeNode, @@ -162,30 +183,36 @@ def _update_parents( # if the node is a duplicated node and deduplication is not done for the node, # start the deduplication transformation use CTE if node.encoded_node_id_with_query in duplicated_node_ids: - if node.encoded_node_id_with_query in resolved_with_block_map: - # if the corresponding CTE block has been created, use the existing - # one. - resolved_with_block = resolved_with_block_map[ - node.encoded_node_id_with_query - ] + if use_per_object and object_id_counts[id(node)] <= 1: + visited_nodes.add(node) + if node in updated_nodes: + _update_parents(node, should_replace_child=False) + continue + + map_key = _map_key_for(node) + if map_key in resolved_with_block_map: + resolved_with_block = resolved_with_block_map[map_key] else: if ( self._query_generator.session.reduce_describe_query_enabled and context._is_snowpark_connect_compatible_mode ): - # create a deterministic name using the first 10 chars of encoded_node_id_with_query (SHA256 hash) - # It helps when DataFrame.queries is called multiple times. - # Consistent CTE names returned, reducing the number of describe queries from cached_analyze_attributes calls. - cte_name = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{node.encoded_node_id_with_query[:HASH_LENGTH].upper()}" + base = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{node.encoded_node_id_with_query[:HASH_LENGTH].upper()}" + if use_per_object: + idx = per_object_cte_counter[ + node.encoded_node_id_with_query + ] + per_object_cte_counter[node.encoded_node_id_with_query] += 1 + cte_name = f"{base}_{idx}" if idx > 0 else base + else: + cte_name = base else: cte_name = random_name_for_temp_object(TempObjectType.CTE) with_block = WithQueryBlock(name=cte_name, child=node) # type: ignore with_block._is_valid_for_replacement = True resolved_with_block = self._query_generator.resolve(with_block) - resolved_with_block_map[ - node.encoded_node_id_with_query - ] = resolved_with_block + resolved_with_block_map[map_key] = resolved_with_block self._total_number_ctes += 1 _update_parents( node, should_replace_child=True, new_child=resolved_with_block diff --git a/tests/integ/test_cte.py b/tests/integ/test_cte.py index 49a2a4ef0f..b82f53d04c 100644 --- a/tests/integ/test_cte.py +++ b/tests/integ/test_cte.py @@ -585,6 +585,65 @@ def test_different_df_same_query(session, is_connect_mode): assert count_number_of_ctes(df.queries["queries"][-1]) == expected_cte_count +def test_mixed_duplicated_and_unique_objects_same_sql(session, is_connect_mode): + """ + union(union(df1, df1), union(df2, df3)) where df1, df2, df3 all produce + identical SQL but are different Python objects. + + In connect mode: + - df1 appears twice (same object) -> should be CTE-deduplicated + - df2 and df3 each appear once -> should NOT be CTE-deduplicated + - Expect 1 CTE (for df1 only) + In non-connect mode: + - All share the same encoded ID -> treated as one CTE + - Expect 1 CTE (all unified) + """ + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df3 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df = df1.union_all(df1).union_all(df2.union_all(df3)) + check_result( + session, + df, + expect_cte_optimized=True, + query_count=1, + describe_count=0, + union_count=3, + join_count=0, + ) + with SqlCounter(query_count=0, describe_count=0): + assert count_number_of_ctes(df.queries["queries"][-1]) == 1 + + +def test_distinct_objects_each_duplicated(session, is_connect_mode): + """ + union(union(df1, df1), union(df2, df2)) where df1 and df2 produce + identical SQL but are different Python objects. + + In connect mode: + - df1 appears twice, df2 appears twice -> each gets its own CTE + - Expect 2 CTEs + In non-connect mode: + - All share the same encoded ID -> one CTE + - Expect 1 CTE + """ + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df2 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).select("a") + df = df1.union_all(df1).union_all(df2.union_all(df2)) + check_result( + session, + df, + expect_cte_optimized=True, + query_count=1, + describe_count=0, + union_count=3, + join_count=0, + ) + with SqlCounter(query_count=0, describe_count=0): + expected_cte_count = 2 if is_connect_mode else 1 + assert count_number_of_ctes(df.queries["queries"][-1]) == expected_cte_count + + def test_same_duplicate_subtree(session): """ root diff --git a/tests/unit/test_cte.py b/tests/unit/test_cte.py index b6db420f8c..4e85fbb77f 100644 --- a/tests/unit/test_cte.py +++ b/tests/unit/test_cte.py @@ -115,7 +115,9 @@ def test_connect_mode_same_object_still_deduplicated(): def test_connect_mode_different_objects_same_id_not_deduplicated(): """When two different Python objects have the same encoded_node_id_with_query (e.g. df1.union_all(df2) where df1 and df2 produce the same SQL), - they should NOT be treated as duplicates in connect-compatible mode.""" + find_duplicate_subtrees flags the encoded ID (raw count > 1), but + the per-object filtering in _replace_duplicate_node_with_cte will + skip them since each object appears only once.""" root = _create_mock_node("root_R") child_a = _create_mock_node("same_S") child_b = _create_mock_node("same_S") @@ -129,18 +131,24 @@ def test_connect_mode_different_objects_same_id_not_deduplicated(): "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True ): duplicated_ids, _ = find_duplicate_subtrees(root) - assert len(duplicated_ids) == 0 + # The encoded IDs are flagged by raw count; per-object filtering + # happens downstream in _replace_duplicate_node_with_cte. + assert "same_S" in duplicated_ids def test_connect_mode_mixed_shared_and_distinct_objects(): """A tree with both shared objects (same ref) and distinct objects (different refs, - same encoded id). Only the shared object should be deduplicated in connect mode. + same encoded id). Tree: root / \\ left right (different objects, same encoded id "branch_B") | | shared shared (same object, appears twice → should be deduplicated) + + find_duplicate_subtrees flags both encoded IDs by raw count. + The per-object filtering in _replace_duplicate_node_with_cte will + skip left/right (each appears once) and only CTE-ify shared. """ root = _create_mock_node("root_R") left = _create_mock_node("branch_B") @@ -154,8 +162,11 @@ def test_connect_mode_mixed_shared_and_distinct_objects(): "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True ): duplicated_ids, _ = find_duplicate_subtrees(root) - assert "shared_S" in duplicated_ids - assert "branch_B" not in duplicated_ids + # branch_B has raw count 2 (left + right) → flagged as duplicate. + # shared_S has raw count 2 but its only parent (branch_B) is also + # duplicated, so it's not the root of a duplicate subtree. + assert "branch_B" in duplicated_ids + assert "shared_S" not in duplicated_ids def test_connect_mode_distinct_objects_each_duplicated(): @@ -216,6 +227,38 @@ def test_connect_mode_with_propagate_complexity_hist(): assert complexity_hist[1] == 2 # 50000 falls in bin 1 (> 10,000, <= 100,000) +def test_connect_mode_mixed_duplicated_and_unique_objects(): + """When multiple distinct objects share the same encoded id but only some + appear more than once, the encoded ID should still be flagged as + duplicated (because at least one object is genuinely duplicated). + + Tree: root + / \\ + union1 union2 + / \\ / \\ + df1 df1 df2 df3 ← df1 appears 2x (duplicated), df2 and df3 appear 1x each + + The encoded ID "same_S" should be in duplicated_node_ids because df1 + appears twice. The per-object filtering (df2/df3 not replaced) is + handled downstream in _replace_duplicate_node_with_cte. + """ + root = _create_mock_node("root_R") + union1 = _create_mock_node("union1_U") + union2 = _create_mock_node("union2_U2") + df1 = _create_mock_node("same_S") + df2 = _create_mock_node("same_S") + df3 = _create_mock_node("same_S") + root.children_plan_nodes = [union1, union2] + union1.children_plan_nodes = [df1, df1] + union2.children_plan_nodes = [df2, df3] + + with mock.patch( + "snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True + ): + duplicated_ids, _ = find_duplicate_subtrees(root) + assert "same_S" in duplicated_ids + + def test_encode_node_id_with_query_select_sql(mock_session, mock_analyzer): sql_text = "select 1 as a, 2 as b" select_sql_node = SelectSQL( From 2789b2d74b822ca78587cfa1cede88e69d34fb68 Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Mon, 6 Apr 2026 16:26:23 -0700 Subject: [PATCH 5/5] simplify implementation --- .../snowpark/_internal/compiler/cte_utils.py | 23 ++++--------- .../compiler/repeated_subquery_elimination.py | 33 ++++++++++--------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/snowflake/snowpark/_internal/compiler/cte_utils.py b/src/snowflake/snowpark/_internal/compiler/cte_utils.py index 7165cccf8e..f9b46f4276 100644 --- a/src/snowflake/snowpark/_internal/compiler/cte_utils.py +++ b/src/snowflake/snowpark/_internal/compiler/cte_utils.py @@ -115,14 +115,15 @@ def traverse(root: "TreeNode") -> None: while len(current_level) > 0: next_level = [] for node in current_level: - encoded_id = node.encoded_node_id_with_query - id_node_map[encoded_id].append(node) + id_node_map[node.encoded_node_id_with_query].append(node) if is_select_from_file_node(node): - invalid_ids_for_deduplication.add(encoded_id) + invalid_ids_for_deduplication.add(node.encoded_node_id_with_query) for child in node.children_plan_nodes: - id_parents_map[child.encoded_node_id_with_query].add(encoded_id) + id_parents_map[child.encoded_node_id_with_query].add( + node.encoded_node_id_with_query + ) next_level.append(child) current_level = next_level @@ -137,16 +138,6 @@ def traverse(root: "TreeNode") -> None: next_level.append(parent_id) current_level = next_level - def _node_occurrence_count(encoded_node_id_with_query: str) -> int: - """How many times this encoded node ID appears in the tree. - - This is a raw count based on encoded ID only. In connect-compatible - mode this may over-count (treating different Python objects with the - same SQL as duplicates). The per-object filtering is handled - downstream in _replace_duplicate_node_with_cte. - """ - return len(id_node_map[encoded_node_id_with_query]) - def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: # when a sql query is a select statement, its encoded_node_id_with_query # contains _, which is used to separate the query id and node type name. @@ -163,10 +154,10 @@ def is_duplicate_subtree(encoded_node_id_with_query: str) -> bool: if encoded_node_id_with_query in invalid_ids_for_deduplication: return False - is_duplicate_node = _node_occurrence_count(encoded_node_id_with_query) > 1 + is_duplicate_node = len(id_node_map[encoded_node_id_with_query]) > 1 if is_duplicate_node: is_any_parent_unique_node = any( - _node_occurrence_count(id) == 1 + len(id_node_map[id]) == 1 for id in id_parents_map[encoded_node_id_with_query] ) if is_any_parent_unique_node: diff --git a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py index 6bb66e5b2b..1c32af1f83 100644 --- a/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py +++ b/src/snowflake/snowpark/_internal/compiler/repeated_subquery_elimination.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # +import hashlib from collections import Counter, defaultdict from typing import Dict, List, Optional, Set @@ -151,14 +152,6 @@ def _replace_duplicate_node_with_cte( # track the resolved WithQueryBlock node has been created for each duplicated node. # In per-object mode the key is str(id(node)); otherwise it is encoded_node_id_with_query. resolved_with_block_map: Dict[str, SnowflakePlan] = {} - # Counter for deterministic CTE naming when multiple distinct objects - # share the same encoded ID (connect mode only). - per_object_cte_counter: Dict[str, int] = defaultdict(int) - - def _map_key_for(node: TreeNode) -> str: - if use_per_object: - return str(id(node)) - return node.encoded_node_id_with_query def _update_parents( node: TreeNode, @@ -189,7 +182,11 @@ def _update_parents( _update_parents(node, should_replace_child=False) continue - map_key = _map_key_for(node) + map_key = ( + node.encoded_node_id_with_query + if not use_per_object + else str(id(node)) + ) if map_key in resolved_with_block_map: resolved_with_block = resolved_with_block_map[map_key] else: @@ -197,15 +194,19 @@ def _update_parents( self._query_generator.session.reduce_describe_query_enabled and context._is_snowpark_connect_compatible_mode ): - base = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{node.encoded_node_id_with_query[:HASH_LENGTH].upper()}" if use_per_object: - idx = per_object_cte_counter[ - node.encoded_node_id_with_query - ] - per_object_cte_counter[node.encoded_node_id_with_query] += 1 - cte_name = f"{base}_{idx}" if idx > 0 else base + obj_hash = ( + hashlib.sha256( + f"{node.encoded_node_id_with_query}:{id(node)}".encode() + ) + .hexdigest()[:HASH_LENGTH] + .upper() + ) else: - cte_name = base + obj_hash = node.encoded_node_id_with_query[ + :HASH_LENGTH + ].upper() + cte_name = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{obj_hash}" else: cte_name = random_name_for_temp_object(TempObjectType.CTE) with_block = WithQueryBlock(name=cte_name, child=node) # type: ignore