diff --git a/tests/test_core.py b/tests/test_core.py index 505c81a..8f50437 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,7 +2,13 @@ from collections import OrderedDict from unification import var -from unification.core import reify, unify, unground_lvars, isground +from unification.core import ( + reify, + unify, + unground_lvars, + isground, + ContractingAssociationMap, +) def test_reify(): @@ -139,3 +145,82 @@ def test_unground_lvars(): assert not isground( ctor((a_lv, sub_ctor((b_lv, 2)), 3)), {a_lv: b_lv, b_lv: var("c")} ) + + +def test_ContractingAssociationMap(): + + a, b, c, d = var("a"), var("b"), var("c"), var("d") + + # Contractions should happen in the constructor + m = ContractingAssociationMap({b: c, c: a, d: d}) + assert m == {b: a, c: a} + + # Order of entry shouldn't matter + m = ContractingAssociationMap([(b, c), (c, a), (d, d)]) + assert m == {b: a, c: a} + + m = ContractingAssociationMap([(c, a), (b, c), (d, d)]) + assert m == {b: a, c: a} + + # Nor should the means of entry + m = ContractingAssociationMap() + m[a] = b + m[b] = c + + assert m == {b: c, a: c} + + m = ContractingAssociationMap() + m[b] = c + m[a] = b + + assert m == {b: c, a: c} + + # Make sure we don't introduce cycles, and that we remove newly imposed + # ones + m[c] = a + assert m == {b: a, c: a} + + m = ContractingAssociationMap([(b, c), (c, b), (d, d)]) + assert m == {c: b} + + m = ContractingAssociationMap([(c, b), (b, c), (d, d)]) + assert m == {b: c} + + # Simulate a long chain + import timeit + + dict_time = timeit.timeit( + stmt=""" + from unification import var, reify + from unification.utils import transitive_get as walk + + m = {} + first_lvar = var() + lvar = first_lvar + for i in range(1000): + m[lvar] = var() + lvar = m[lvar] + m[lvar] = 1 + + assert walk(first_lvar, m) == 1 + """, + number=10, + ) + + cmap_time = timeit.timeit( + stmt=""" + from unification import var, reify + from unification.core import ContractingAssociationMap + + m = ContractingAssociationMap() + first_lvar = var() + lvar = first_lvar + for i in range(1000): + m[lvar] = var() + lvar = m[lvar] + m[lvar] = 1 + + assert m[first_lvar] == 1 + """, + number=10, + ) diff --git a/unification/core.py b/unification/core.py index 436d8db..c5d07cf 100644 --- a/unification/core.py +++ b/unification/core.py @@ -1,7 +1,8 @@ from toolz import assoc from operator import length_hint from functools import partial -from collections import OrderedDict + +from collections import OrderedDict, UserDict from collections.abc import Iterator, Mapping, Set from .utils import transitive_get as walk @@ -13,6 +14,76 @@ class UngroundLVarException(Exception): """An exception signaling that an unground variables was found.""" +class ContractingAssociationMap(UserDict): + """A map that contracts association chains. + + For instance, if we add the logic variable association a -> b to a + `ContractingAssociationMap` containing a logic variable associationn like + {c -> a, ...}, the result will be {c -> b, a -> b, ...}. + """ + + def __init__(self, *args, **kwargs): + self.inverse = {} + super().__init__(*args, **kwargs) + # TODO + # self.data = WeakKeyDictionary(self.data) + + def __setitem__(self, key, value): + + assert isvar(key) + + # Self-associations are a waste + if key == value: + return + + # Get the (one-step) walked value + walk_value = self.data.get(value, value) + + # Get the keys that have this key as a value + key_keys = self.inverse.setdefault(key, []) + + if walk_value == key: + # The walked value equals the key, so we need to drop the `value` + # association before we loop and reintroduce it + del self[value] + walk_value = value + + for key_key in tuple(key_keys): + + # old_val = self.data[key_key] + + # Remove the old mapping and its inverse + del self[key_key] + + # Don't add self associations + if (key_key == value) or (key_key == walk_value): + continue + + # Remap those association keys to this lvar + # TODO: Flatten-out this tail recursion? + self[key_key] = walk_value + + if isvar(walk_value): + # Add the new association's inverse + self.inverse.setdefault(walk_value, []).append(key) + + super().__setitem__(key, walk_value) + + def __delitem__(self, key): + + val = self.data[key] + + if isvar(val): + + # Remove the inverse association + self.inverse.setdefault(val, []).remove(key) + + if val in self.inverse and not self.inverse[val]: + del self.inverse[val] + + super().__delitem__(key) + + @dispatch(object, Mapping) def _reify(o, s): return o