Skip to content

Commit 330a1e0

Browse files
committed
Revert "Use non recursive algorithm in rebuild_collect_shared"
This reverts commit ae43c14. Breaks the @pytensor_jit example
1 parent 90ab712 commit 330a1e0

File tree

1 file changed

+40
-47
lines changed

1 file changed

+40
-47
lines changed

pytensor/compile/function/pfunc.py

Lines changed: 40 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -179,54 +179,47 @@ def clone_v_get_shared_updates(v, copy_inputs_over):
179179
180180
"""
181181
# this co-recurses with clone_a
182-
stack = [v]
183-
try:
184-
while True:
185-
v = stack.pop()
186-
if v in clone_d:
187-
continue
188-
if (apply := v.owner) is not None:
189-
if all(i in clone_d for i in apply.inputs):
190-
# all inputs have been cloned, we can clone this node
191-
clone_node_and_cache(
192-
apply,
193-
clone_d,
194-
strict=rebuild_strict,
195-
clone_inner_graphs=clone_inner_graphs,
182+
assert v is not None
183+
if v in clone_d:
184+
return clone_d[v]
185+
if v.owner:
186+
owner = v.owner
187+
if owner not in clone_d:
188+
for i in owner.inputs:
189+
clone_v_get_shared_updates(i, copy_inputs_over)
190+
clone_node_and_cache(
191+
owner,
192+
clone_d,
193+
strict=rebuild_strict,
194+
clone_inner_graphs=clone_inner_graphs,
195+
)
196+
return clone_d.setdefault(v, v)
197+
elif isinstance(v, SharedVariable):
198+
if v not in shared_inputs:
199+
shared_inputs.append(v)
200+
if v.default_update is not None:
201+
# Check that v should not be excluded from the default
202+
# updates list
203+
if no_default_updates is False or (
204+
isinstance(no_default_updates, list) and v not in no_default_updates
205+
):
206+
# Do not use default_update if a "real" update was
207+
# provided
208+
if v not in update_d:
209+
v_update = v.type.filter_variable(
210+
v.default_update, allow_convert=False
196211
)
197-
else:
198-
# expand on the inputs
199-
stack.extend(apply.inputs)
200-
else:
201-
clone_d[v] = v if copy_inputs_over else v.clone()
202-
203-
# Special handling of SharedVariables
204-
if isinstance(v, SharedVariable):
205-
if v not in shared_inputs:
206-
shared_inputs.append(v)
207-
if v.default_update is not None:
208-
# Check that v should not be excluded from the default
209-
# updates list
210-
if no_default_updates is False or (
211-
isinstance(no_default_updates, list)
212-
and v not in no_default_updates
213-
):
214-
# Do not use default_update if a "real" update was
215-
# provided
216-
if v not in update_d:
217-
v_update = v.type.filter_variable(
218-
v.default_update, allow_convert=False
219-
)
220-
if not v.type.is_super(v_update.type):
221-
raise TypeError(
222-
"An update must have a type compatible with "
223-
"the original shared variable"
224-
)
225-
update_d[v] = v_update
226-
update_expr.append((v, v_update))
227-
except IndexError:
228-
pass # stack is empty
229-
return clone_d[v]
212+
if not v.type.is_super(v_update.type):
213+
raise TypeError(
214+
"An update must have a type compatible with "
215+
"the original shared variable"
216+
)
217+
update_d[v] = v_update
218+
update_expr.append((v, v_update))
219+
if not copy_inputs_over:
220+
return clone_d.setdefault(v, v.clone())
221+
else:
222+
return clone_d.setdefault(v, v)
230223

231224
# initialize the clone_d mapping with the replace dictionary
232225
if replace is None:

0 commit comments

Comments
 (0)