Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

from collections import defaultdict
import hashlib
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Set

from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan
Expand Down Expand Up @@ -117,26 +118,39 @@ def _replace_duplicate_node_with_cte(
Replace all duplicated nodes with a WithQueryBlock (CTE node), to enable
query generation with CTEs.

In connect-compatible mode, per-object occurrence counts are built
during the traversal so that only nodes whose specific Python object
appears more than once are CTE-replaced. Different objects that
share the same ``encoded_node_id_with_query`` but each appear only
once are left as inline subqueries.

NOTE: we use a stack to perform a post-order traversal instead of recursive calls.
The reason for using the stack approach is that chained CTEs have to be built
from bottom (innermost subquery) to top (outermost query).
This function uses an iterative approach to avoid hitting Python's maximum recursion depth limit.
"""
use_per_object = context._is_snowpark_connect_compatible_mode

node_parents_map: Dict[TreeNode, Set[TreeNode]] = defaultdict(set)
# Per-object occurrence counts, built during the traversal below.
# Only populated in connect-compatible mode.
object_id_counts: Counter = Counter()
stack1, stack2 = [root], []

while stack1:
node = stack1.pop()
stack2.append(node)
if use_per_object:
object_id_counts[id(node)] += 1
for child in reversed(node.children_plan_nodes):
node_parents_map[child].add(node)
stack1.append(child)

# track nodes that have already been visited to avoid repeated operations on the same node
visited_nodes: Set[TreeNode] = set()
updated_nodes: Set[TreeNode] = set()
# track the resolved WithQueryBlock node has been created for each duplicated node
# track the resolved WithQueryBlock node that has been created for each duplicated node.
# In per-object mode the key is str(id(node)); otherwise it is encoded_node_id_with_query.
resolved_with_block_map: Dict[str, SnowflakePlan] = {}

def _update_parents(
Expand All @@ -162,30 +176,44 @@ def _update_parents(
# if the node is a duplicated node and deduplication is not done for the node,
# start the deduplication transformation using a CTE
if node.encoded_node_id_with_query in duplicated_node_ids:
if node.encoded_node_id_with_query in resolved_with_block_map:
# if the corresponding CTE block has been created, use the existing
# one.
resolved_with_block = resolved_with_block_map[
node.encoded_node_id_with_query
]
if use_per_object and object_id_counts[id(node)] <= 1:
visited_nodes.add(node)
if node in updated_nodes:
_update_parents(node, should_replace_child=False)
continue

map_key = (
node.encoded_node_id_with_query
if not use_per_object
else str(id(node))
)
if map_key in resolved_with_block_map:
resolved_with_block = resolved_with_block_map[map_key]
else:
if (
self._query_generator.session.reduce_describe_query_enabled
and context._is_snowpark_connect_compatible_mode
):
# create a deterministic name using the first 10 chars of encoded_node_id_with_query (SHA256 hash)
# It helps when DataFrame.queries is called multiple times.
# Consistent CTE names are returned, reducing the number of describe queries from cached_analyze_attributes calls.
cte_name = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{node.encoded_node_id_with_query[:HASH_LENGTH].upper()}"
if use_per_object:
obj_hash = (
hashlib.sha256(
f"{node.encoded_node_id_with_query}:{id(node)}".encode()
)
.hexdigest()[:HASH_LENGTH]
.upper()
)
else:
obj_hash = node.encoded_node_id_with_query[
:HASH_LENGTH
].upper()
cte_name = f"{TEMP_OBJECT_NAME_PREFIX}{TempObjectType.CTE.value}_{obj_hash}"
else:
cte_name = random_name_for_temp_object(TempObjectType.CTE)
with_block = WithQueryBlock(name=cte_name, child=node) # type: ignore
with_block._is_valid_for_replacement = True

resolved_with_block = self._query_generator.resolve(with_block)
resolved_with_block_map[
node.encoded_node_id_with_query
] = resolved_with_block
resolved_with_block_map[map_key] = resolved_with_block
self._total_number_ctes += 1
_update_parents(
node, should_replace_child=True, new_child=resolved_with_block
Expand Down
Loading
Loading