@@ -203,8 +203,8 @@ class _Node:
203203 | | |
204204 .......
205205
206- A.backtracks[0] will be [B, C], B.backtracks[0] will be [C] , and
207- C.backtracks[0] will be [] .
206+ A.backtracks[0] will be {A}, B.backtracks[0] will be {A, B} , and
207+ C.backtracks[0] will be {A, B, C} .
208208 """
209209
210210 def __init__ (self , parent , dep_parent , state , action , backtracks , depth , breadth ):
@@ -259,7 +259,7 @@ def chain(self, node: _Node) -> Iterable[_Node]:
259259 if not action :
260260 continue
261261
262- if not self .check_action (node , action ):
262+ if not self .check_action (node , node . state , action ):
263263 continue
264264
265265 state = self .apply (node , action )
@@ -272,8 +272,8 @@ def chain(self, node: _Node) -> Iterable[_Node]:
272272
273273 for i , action in enumerate (actions ):
274274 # Only allow backtracking into later actions, to avoid duplication
275- remaining = partials [ i + 1 :]
276- backtracks = node .backtracks + [remaining ]
275+ used = set ( actions [: i + 1 ])
276+ backtracks = node .backtracks + [used ]
277277 yield _Node (node , node , states [i ], action , backtracks , node .depth + 1 , node .breadth )
278278
279279 def backtrack (self , node : _Node ) -> Iterable [_Node ]:
@@ -284,21 +284,37 @@ def backtrack(self, node: _Node) -> Iterable[_Node]:
284284 if node .breadth >= self .max_breadth :
285285 return
286286
287- for i , partials in enumerate (node .backtracks ):
288- backtracks = node .backtracks [:i ]
287+ parent = node .dep_parent
288+ parents = []
289+ while parent :
290+ parents .append (parent )
291+ parent = parent .dep_parent
292+ parents = parents [::- 1 ]
289293
290- for j , partial in enumerate (partials ):
294+ for parent in parents :
295+ rules = self .options .get_rules (parent .depth )
296+ assignments = self .all_assignments (node , rules )
297+ if self .rng :
298+ self .rng .shuffle (assignments )
299+
300+ for partial in assignments :
291301 action = self .try_instantiate (node .state , partial )
292302 if not action :
293303 continue
294304
305+ used = node .backtracks [parent .depth ]
306+ if action in used :
307+ continue
308+
309+ if not self .check_action (parent , node .state , action ):
310+ continue
311+
295312 state = self .apply (node , action )
296313 if not state :
297314 continue
298315
299- remaining = partials [j + 1 :]
300- new_backtracks = backtracks + [remaining ]
301- yield _Node (node , partial .node , state , action , new_backtracks , i + 1 , node .breadth + 1 )
316+ backtracks = node .backtracks [:parent .depth ] + [used | {action }]
317+ yield _Node (node , parent , state , action , backtracks , parent .depth + 1 , node .breadth + 1 )
302318
303319 def all_assignments (self , node : _Node , rules : Iterable [Rule ]) -> Iterable [_PartialAction ]:
304320 """
@@ -359,7 +375,7 @@ def create_variable(self, state, ph, type_counts):
359375 type_counts [ph .type ] += 1
360376 return var
361377
362- def check_action (self , node : _Node , action : Action ) -> bool :
378+ def check_action (self , node : _Node , state : State , action : Action ) -> bool :
363379 # Find the last action before a navigation action
364380 # TODO: Fold this behaviour into ChainingOptions.check_action()
365381 nav_parent = node
@@ -387,7 +403,7 @@ def check_action(self, node: _Node, action: Action) -> bool:
387403 if len (recent .added & relevant ) == 0 or len (pre_navigation .added & relevant ) == 0 :
388404 return False
389405
390- return self .options .check_action (node . state , action )
406+ return self .options .check_action (state , action )
391407
392408 def _is_navigation (self , action ):
393409 return action .name .startswith ("go/" )
@@ -406,6 +422,9 @@ def apply(self, node: _Node, action: Action) -> Optional[State]:
406422 new_state .apply (action )
407423
408424 # Some debug checks
425+ # XXX
426+ if not self .check_state (new_state ):
427+ return None
409428 assert self .check_state (new_state )
410429
411430 # Detect cycles
0 commit comments