diff --git a/textworld/generator/chaining.py b/textworld/generator/chaining.py index 4a9c953a..6b5f8ddb 100644 --- a/textworld/generator/chaining.py +++ b/textworld/generator/chaining.py @@ -76,6 +76,8 @@ class ChainingOptions: subquests: Whether to also return incomplete quests, which could be extended without reaching the depth or breadth limits. + independent_chains: + Whether to allow totally independent parallel chains. create_variables: Whether new variables may be created during chaining. fixed_mapping: @@ -99,6 +101,7 @@ def __init__(self): self.min_breadth = 1 self.max_breadth = 1 self.subquests = False + self.independent_chains = False self.create_variables = False self.fixed_mapping = data.get_types().constants_mapping self.rng = None @@ -189,30 +192,17 @@ class _Node: A node in a chain being generated. Each node is aware of its position (depth, breadth) in the dependency tree - induced by the chain. For generating parallel quests, the backtracks field - holds actions that can be use to go up the dependency tree and start a new - chain. - - For example, taking the action node.backtracks[i][j] will produce a new node - at depth (i + 1) and breadth (self.breadth + 1). To avoid duplication, in - trees like this: - - root - / | \ - A B C - | | | - ....... - - A.backtracks[0] will be [B, C], B.backtracks[0] will be [C], and - C.backtracks[0] will be []. + induced by the chain. To avoid duplication when generating parallel chains, + each node stores the actions that have already been used at that depth. """ - def __init__(self, parent, dep_parent, state, action, backtracks, depth, breadth): + def __init__(self, parent, dep_parent, state, action, rules, used, depth, breadth): self.parent = parent self.dep_parent = dep_parent self.state = state self.action = action - self.backtracks = backtracks + self.rules = rules + self.used = used self.depth = depth self.breadth = breadth @@ -235,7 +225,7 @@ def __init__(self, state, options): def root(self) -> _Node: """Create the root node for chaining.""" - return _Node(None, None, self.state, None, [], 0, 1) + return _Node(None, None, self.state, None, [], set(), 0, 1) def chain(self, node: _Node) -> Iterable[_Node]: """ @@ -251,30 +241,21 @@ def chain(self, node: _Node) -> Iterable[_Node]: if self.rng: self.rng.shuffle(assignments) - partials = [] - actions = [] - states = [] + used = set() for partial in assignments: action = self.try_instantiate(node.state, partial) if not action: continue - if not self.check_action(node, action): + if not self.check_action(node, node.state, action): continue state = self.apply(node, action) if not state: continue - partials.append(partial) - actions.append(action) - states.append(state) - - for i, action in enumerate(actions): - # Only allow backtracking into later actions, to avoid duplication - remaining = partials[i+1:] - backtracks = node.backtracks + [remaining] - yield _Node(node, node, states[i], action, backtracks, node.depth + 1, node.breadth) + used = used | {action} + yield _Node(node, node, state, action, rules, used, node.depth + 1, node.breadth) def backtrack(self, node: _Node) -> Iterable[_Node]: """ @@ -284,21 +265,39 @@ def backtrack(self, node: _Node) -> Iterable[_Node]: if node.breadth >= self.max_breadth: return - for i, partials in enumerate(node.backtracks): - backtracks = node.backtracks[:i] - - for j, partial in enumerate(partials): + parent = node + parents = [] + while parent.dep_parent: + if parent.depth == 1 and not self.options.independent_chains: + break + parents.append(parent) + parent = parent.dep_parent + parents = parents[::-1] + + for sibling in parents: + parent = sibling.dep_parent + rules = self.options.get_rules(parent.depth) + assignments = self.all_assignments(node, rules) + if self.rng: + self.rng.shuffle(assignments) + + for partial in assignments: action = self.try_instantiate(node.state, partial) if not action: continue + if action in sibling.used: + continue + + if not self.check_action(parent, node.state, action): + continue + state = self.apply(node, action) if not state: continue - remaining = partials[j+1:] - new_backtracks = backtracks + [remaining] - yield _Node(node, partial.node, state, action, new_backtracks, i + 1, node.breadth + 1) + used = sibling.used | {action} + yield _Node(node, parent, state, action, rules, used, sibling.depth, node.breadth + 1) def all_assignments(self, node: _Node, rules: Iterable[Rule]) -> Iterable[_PartialAction]: """ @@ -359,7 +358,7 @@ def create_variable(self, state, ph, type_counts): type_counts[ph.type] += 1 return var - def check_action(self, node: _Node, action: Action) -> bool: + def check_action(self, node: _Node, state: State, action: Action) -> bool: # Find the last action before a navigation action # TODO: Fold this behaviour into ChainingOptions.check_action() nav_parent = node @@ -387,7 +386,7 @@ def check_action(self, node: _Node, action: Action) -> bool: if len(recent.added & relevant) == 0 or len(pre_navigation.added & relevant) == 0: return False - return self.options.check_action(node.state, action) + return self.options.check_action(state, action) def _is_navigation(self, action): return action.name.startswith("go/") @@ -405,8 +404,8 @@ def apply(self, node: _Node, action: Action) -> Optional[State]: new_state.apply(action) - # Some debug checks - assert self.check_state(new_state) + if not self.check_state(new_state): + return None # Detect cycles state = new_state.copy() diff --git a/textworld/generator/tests/test_chaining.py b/textworld/generator/tests/test_chaining.py index 2477525b..b2b555c5 100644 --- a/textworld/generator/tests/test_chaining.py +++ b/textworld/generator/tests/test_chaining.py @@ -237,4 +237,73 @@ def test_parallel_quests(): options.min_breadth = 1 options.create_variables = True chains = list(get_chains(State(), options)) - assert len(chains) == 6 + assert len(chains) == 5 + + +def test_parallel_quests_navigation(): + logic = GameLogic.parse(""" + type P { + } + + type I { + } + + type r { + rules { + move :: at(P, r) & $free(r, r') -> at(P, r'); + } + + constraints { + atat :: at(P, r) & at(P, r') -> fail(); + } + } + + type o { + rules { + take :: $at(P, r) & at(o, r) -> in(o, I); + } + + constraints { + inat :: in(o, I) & at(o, r) -> fail(); + } + } + + type flour : o { + } + + type eggs : o { + } + + type cake { + rules { + bake :: in(flour, I) & in(eggs, I) -> in(cake, I) & in(flour, cake) & in(eggs, cake); + } + + constraints { + inincake :: in(o, I) & in(o, cake) -> fail(); + atincake :: at(o, r) & in(o, cake) -> fail(); + } + } + """) + + state = State([ + Proposition.parse("at(P, r3: r)"), + Proposition.parse("free(r2: r, r3: r)"), + Proposition.parse("free(r1: r, r2: r)"), + ]) + + bake = [logic.rules["bake"]] + non_bake = [r for r in logic.rules.values() if r.name != "bake"] + + options = ChainingOptions() + options.backward = True + options.create_variables = True + options.min_depth = 3 + options.max_depth = 3 + options.min_breadth = 2 + options.max_breadth = 2 + options.logic = logic + options.rules_per_depth = [bake, non_bake, non_bake] + options.restricted_types = {"P", "r"} + chains = list(get_chains(state, options)) + assert len(chains) == 2