Skip to content

Commit cb128b0

Browse files
committed
Fail fast scan memory inplace
1 parent 6cf5842 commit cb128b0

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytensor/scan/rewriting.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,8 +1048,13 @@ def attempt_scan_inplace(
10481048
return None
10491049

10501050
def apply(self, fgraph):
1051+
scan_nodes = {node for node in fgraph.apply_nodes if isinstance(node.op, Scan)}
1052+
1053+
if not scan_nodes:
1054+
return
1055+
10511056
for scan_idx, original_node in enumerate(reversed(fgraph.toposort())):
1052-
if not isinstance(original_node.op, Scan):
1057+
if original_node not in scan_nodes:
10531058
continue
10541059

10551060
# First attempt to make the Scan compute inplace every recurrent

0 commit comments

Comments
 (0)