From 063bd51de18924deb3fcb5cb225ac31e98092295 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:37 -0600 Subject: [PATCH 01/14] Rename to test function to avoid confusion with package function --- tests/test_graph.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 965240d..51e50da 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -71,7 +71,7 @@ def single_math_reduceo(expanded_term, reduced_term): math_reduceo = partial(reduceo, single_math_reduceo) -term_walko = partial( +walko_term = partial( walko, rator_goal=eq, null_type=ExpressionTuple, @@ -413,11 +413,11 @@ def test_walko(test_input, test_output): """Test `walko` with fully ground terms (i.e. no logic variables).""" q_lv = var() - term_walko_fp = partial(reduceo, partial(term_walko, single_math_reduceo)) + walko_term_fp = partial(reduceo, partial(walko_term, single_math_reduceo)) test_res = run( len(test_output), q_lv, - term_walko_fp(test_input, q_lv), + walko_term_fp(test_input, q_lv), results_filter=toolz.unique, ) @@ -438,7 +438,7 @@ def test_walko_reverse(): """Test `walko` in "reverse" (i.e. specify the reduced form and generate the un-reduced form).""" # noqa: E501 q_lv = var("q") - test_res = run(2, q_lv, term_walko(math_reduceo, q_lv, 5)) + test_res = run(2, q_lv, walko_term(math_reduceo, q_lv, 5)) assert test_res == ( etuple(log, etuple(exp, 5)), etuple(log, etuple(exp, etuple(log, etuple(exp, 5)))), @@ -446,7 +446,7 @@ def test_walko_reverse(): assert all(e.eval_obj == 5.0 for e in test_res) # Make sure we get some variety in the results - test_res = run(2, q_lv, term_walko(math_reduceo, q_lv, etuple(mul, 2, 5))) + test_res = run(2, q_lv, walko_term(math_reduceo, q_lv, etuple(mul, 2, 5))) assert test_res == ( # Expansion of the term's root etuple(add, 5, 5), @@ -460,7 +460,7 @@ def test_walko_reverse(): assert all(e.eval_obj == 10.0 for e in test_res) r_lv = var("r") - test_res = run(4, [q_lv, r_lv], term_walko(math_reduceo, q_lv, r_lv)) + test_res = run(4, [q_lv, r_lv], walko_term(math_reduceo, q_lv, r_lv)) expect_res = ( [etuple(add, 1, 1), etuple(mul, 2, 1)], [etuple(log, etuple(exp, etuple(add, 1, 1))), etuple(mul, 2, 1)], From f1b1350bd0201bd334ee53549dc49a5b730622cf Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Sun, 14 Nov 2021 13:12:40 -0600 Subject: [PATCH 02/14] Clean up duplicate test code --- tests/test_assoccomm.py | 76 ++++++--------------------------- tests/test_term.py | 94 +++++++++++------------------------------ tests/utils.py | 60 ++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 133 deletions(-) create mode 100644 tests/utils.py diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index 3795701..90dddd1 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -1,5 +1,3 @@ -from collections.abc import Sequence - import pytest from cons import cons from etuples.core import etuple @@ -18,60 +16,7 @@ ) from kanren.core import run from kanren.facts import fact -from kanren.term import arguments, operator, term - - -class Node(object): - def __init__(self, op, args): - self.op = op - self.args = args - - def __eq__(self, other): - return ( - type(self) == type(other) - and self.op == other.op - and self.args == other.args - ) - - def __hash__(self): - return hash((type(self), self.op, self.args)) - - def __str__(self): - return "%s(%s)" % (self.op.name, ", ".join(map(str, self.args))) - - __repr__ = __str__ - - -class Operator(object): - def __init__(self, name): - self.name = name - - -Add = Operator("add") -Mul = Operator("mul") - - -def add(*args): - return Node(Add, args) - - -def mul(*args): - return Node(Mul, args) - - -@term.register(Operator, Sequence) -def term_Operator(op, args): - return Node(op, args) - - -@arguments.register(Node) -def arguments_Node(n): - return n.args - - -@operator.register(Node) -def operator_Node(n): - return n.op +from tests.utils import Add def results(g, s=None): @@ -174,15 +119,15 @@ def test_eq_comm(): @pytest.mark.xfail(reason="`applyo`/`buildo` needs to be a constraint.", strict=True) def test_eq_comm_object(): - x = var("x") + x = var() fact(commutative, Add) fact(associative, Add) - assert run(0, x, eq_comm(add(1, 2, 3), add(3, 1, x))) == (2,) - assert set(run(0, x, eq_comm(add(1, 2), x))) == set((add(1, 2), add(2, 1))) - assert set(run(0, x, eq_assoccomm(add(1, 2, 3), add(1, x)))) == set( - (add(2, 3), add(3, 2)) + assert run(0, x, eq_comm(Add(1, 2, 3), Add(3, 1, x))) == (2,) + assert set(run(0, x, eq_comm(Add(1, 2), x))) == set((Add(1, 2), Add(2, 1))) + assert set(run(0, x, eq_assoccomm(Add(1, 2, 3), Add(1, x)))) == set( + (Add(2, 3), Add(3, 2)) ) @@ -575,6 +520,9 @@ def test_assoccomm_objects(): x = var() - assert run(0, True, eq_assoccomm(add(1, 2, 3), add(3, 1, 2))) == (True,) - assert run(0, x, eq_assoccomm(add(1, 2, 3), add(1, 2, x))) == (3,) - assert run(0, x, eq_assoccomm(add(1, 2, 3), add(x, 2, 1))) == (3,) + assert run(0, True, eq_assoccomm(Add(1, 2, 3), Add(3, 1, 2))) == (True,) + # FYI: If `Node` is made `unifiable_with_term` (along with `term`, + # `operator`, and `arguments` implementations), you'll get duplicate + # results in the following test (i.e. `(3, 3)`). 
+ assert run(0, x, eq_assoccomm(Add(1, 2, 3), Add(1, 2, x))) == (3,) + assert run(0, x, eq_assoccomm(Add(1, 2, 3), Add(x, 2, 1))) == (3,) diff --git a/tests/test_term.py b/tests/test_term.py index 740b645..8c79203 100644 --- a/tests/test_term.py +++ b/tests/test_term.py @@ -3,66 +3,8 @@ from unification import reify, unify, var from kanren.core import run -from kanren.term import applyo, arguments, operator, term, unifiable_with_term - - -@unifiable_with_term -class Node(object): - def __init__(self, op, args): - self.op = op - self.args = args - - def __eq__(self, other): - return ( - type(self) == type(other) - and self.op == other.op - and self.args == other.args - ) - - def __hash__(self): - return hash((type(self), self.op, self.args)) - - def __str__(self): - return "%s(%s)" % (self.op.name, ", ".join(map(str, self.args))) - - __repr__ = __str__ - - -class Operator(object): - def __init__(self, name): - self.name = name - - -Add = Operator("add") -Mul = Operator("mul") - - -def add(*args): - return Node(Add, args) - - -def mul(*args): - return Node(Mul, args) - - -class Op(object): - def __init__(self, name): - self.name = name - - -@arguments.register(Node) -def arguments_Node(t): - return t.args - - -@operator.register(Node) -def operator_Node(t): - return t.op - - -@term.register(Operator, (list, tuple)) -def term_Op(op, args): - return Node(op, args) +from kanren.term import applyo, arguments, operator, term +from tests.utils import Add, Node, Operator def test_applyo(): @@ -103,21 +45,35 @@ def test_applyo(): def test_applyo_object(): x = var() - assert run(0, x, applyo(Add, (1, 2, 3), x)) == (add(1, 2, 3),) - assert run(0, x, applyo(x, (1, 2, 3), add(1, 2, 3))) == (Add,) - assert run(0, x, applyo(Add, x, add(1, 2, 3))) == ((1, 2, 3),) + assert run(0, x, applyo(Add, (1, 2, 3), x)) == (Add(1, 2, 3),) + assert run(0, x, applyo(x, (1, 2, 3), Add(1, 2, 3))) == (Add,) + assert run(0, x, applyo(Add, x, Add(1, 2, 3))) == ((1, 2, 3),) -def test_unifiable_with_term(): - add = Operator("add") - t = Node(add, (1, 2)) +def test_term_dispatch(): + + t = Node(Add, (1, 2)) assert arguments(t) == (1, 2) - assert operator(t) == add + assert operator(t) == Add assert term(operator(t), arguments(t)) == t + +def test_unifiable_with_term(): + from kanren.term import unifiable_with_term + + @unifiable_with_term + class NewNode(Node): + pass + + class NewOperator(Operator): + def __call__(self, *args): + return NewNode(self, args) + + NewAdd = NewOperator("newadd") + x = var() - s = unify(Node(add, (1, x)), Node(add, (1, 2)), {}) + s = unify(NewNode(NewAdd, (1, x)), NewNode(NewAdd, (1, 2)), {}) assert s == {x: 2} - assert reify(Node(add, (1, x)), s) == Node(add, (1, 2)) + assert reify(NewNode(NewAdd, (1, x)), s) == NewNode(NewAdd, (1, 2)) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..e69c6a1 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,60 @@ +from kanren.term import arguments, operator + + +class Node(object): + def __init__(self, op, args): + self.op = op + self.args = args + + def __eq__(self, other): + return ( + type(self) == type(other) + and self.op == other.op + and self.args == other.args + ) + + def __hash__(self): + return hash((type(self), self.op, self.args)) + + def __str__(self): + return "%s(%s)" % (self.op.name, ", ".join(map(str, self.args))) + + __repr__ = __str__ + + +class Operator(object): + def __init__(self, name): + self.name = name + + def __call__(self, *args): + return Node(self, args) + + def __eq__(self, other): + return type(self) == 
type(other) and self.name == other.name + + def __hash__(self): + return hash((type(self), self.name)) + + def __str__(self): + return self.name + + __repr__ = __str__ + + +Add = Operator("add") +Mul = Operator("mul") + + +@arguments.register(Node) +def arguments_Node(t): + return t.args + + +@operator.register(Node) +def operator_Node(t): + return t.op + + +# @term.register(Operator, Sequence) +# def term_Operator(op, args): +# return Node(op, args) From 47fac72105a1eb568f73e12af7aa780fc2d12052 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:41 -0600 Subject: [PATCH 03/14] Consolidate and make mapo functions variadic --- kanren/graph.py | 123 ++++++++++++++++++++++++++++-------------------- 1 file changed, 71 insertions(+), 52 deletions(-) diff --git a/kanren/graph.py b/kanren/graph.py index 3fad195..2a8a21e 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -8,75 +8,94 @@ from .term import applyo -def mapo(relation, a, b, null_type=list, null_res=True, first=True): - """Apply a relation to corresponding elements in two sequences and succeed if the relation succeeds for all pairs.""" # noqa: E501 +def mapo(*args, null_res=True, **kwargs): + """Apply a relation to corresponding elements in two sequences and succeed if the relation succeeds for all sets of elements. # noqa: E501 - b_car, b_cdr = var(), var() - a_car, a_cdr = var(), var() - - return conde( - [nullo(a, b, default_ConsNull=null_type) if (not first or null_res) else fail], - [ - conso(a_car, a_cdr, a), - conso(b_car, b_cdr, b), - Zzz(relation, a_car, b_car), - Zzz(mapo, relation, a_cdr, b_cdr, null_type=null_type, first=False), - ], - ) + See `map_anyo` for parameter descriptions. + """ + return map_anyo(*args, null_res=null_res, _first=True, _any_succeed=None, **kwargs) def map_anyo( - relation, a, b, null_type=list, null_res=False, first=True, any_succeed=False + relation, + *args, + null_type=list, + null_res=False, + _first=True, + _any_succeed=False, + **kwargs ): - """Apply a relation to corresponding elements in two sequences and succeed if at least one pair succeeds. + """Apply a relation to corresponding elements across sequences and succeed if at least one set of elements succeeds. Parameters ---------- + relation: Callable + The goal to apply across elements (`car`s, specifically) of `args`. + *args: Sequence + Argument list containing terms that are walked and evaluated as + `relation(car(a_1), car(a_2), ...)`. null_type: optional An object that's a valid cdr for the collection type desired. If `False` (i.e. the default value), the cdr will be inferred from the inputs, or defaults to an empty list. + null_res: bool + Succeed on empty lists. + _first: bool + Indicate whether or not this is the first iteration in a call to this + goal constructor (in contrast to a recursive call). + This is not a user-level parameter. + _any_succeed: bool or None + Indicate whether or not an iteration has succeeded in a recursive call + to this goal, or, if `None`, indicate that only the relation against the + `cars` should be checked (i.e. no "any" functionality). + This is not a user-level parameter. + **kwargs: dict + Keyword arguments to `relation`. 
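For illustration only (this example is not part of the patch), a minimal sketch of how the variadic forms might be called; the `parent` relation and its facts below are invented for the example:

    from unification import var
    from kanren.core import run
    from kanren.facts import Relation, fact
    from kanren.graph import map_anyo, mapo

    # A toy relation used only for this sketch.
    parent = Relation("parent")
    fact(parent, "Abe", "Homer")
    fact(parent, "Homer", "Bart")

    x = var()

    # `mapo` requires `parent` to hold element-wise at every position, so
    # `x` unifies with the corresponding children, e.g. ["Homer", "Bart"].
    run(0, x, mapo(parent, ["Abe", "Homer"], x))

    # `map_anyo` only requires at least one position to satisfy `parent`;
    # the remaining positions may simply be equal across the sequences.
    run(0, x, map_anyo(parent, ["Abe", "Homer"], x))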
""" # noqa: E501 - b_car, b_cdr = var(), var() - a_car, a_cdr = var(), var() + cars = tuple(var() for a in args) + cdrs = tuple(var() for a in args) - return conde( - [ - nullo(a, b, default_ConsNull=null_type) - if (any_succeed or (first and null_res)) - else fail - ], + conde_branches = [ [ - conso(a_car, a_cdr, a), - conso(b_car, b_cdr, b), - conde( - [ - Zzz(relation, a_car, b_car), - Zzz( - map_anyo, - relation, - a_cdr, - b_cdr, - null_type=null_type, - any_succeed=True, - first=False, - ), - ], - [ - eq(a_car, b_car), - Zzz( - map_anyo, - relation, - a_cdr, - b_cdr, - null_type=null_type, - any_succeed=any_succeed, - first=False, - ), - ], + Zzz(relation, *cars, **kwargs), + Zzz( + map_anyo, + relation, + *cdrs, + null_type=null_type, + null_res=null_res, + _first=False, + _any_succeed=True if _any_succeed is not None else None, + **kwargs ), - ], + ] + ] + + if _any_succeed is not None: + nullo_condition = _any_succeed or (_first and null_res) + conde_branches.append( + [eq(a_car, b_car) for a_car, b_car in zip(cars, cars[1:])] + + [ + Zzz( + map_anyo, + relation, + *cdrs, + null_type=null_type, + null_res=null_res, + _first=False, + _any_succeed=_any_succeed, + **kwargs + ), + ] + ) + else: + nullo_condition = not _first or null_res + + return conde( + [nullo(*args, default_ConsNull=null_type) if nullo_condition else fail], + [conso(car, cdr, arg) for car, cdr, arg in zip(cars, cdrs, args)] + + [conde(*conde_branches)], ) From 1f37f65beb34684d5a1bb57aa74f24910772009f Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:43 -0600 Subject: [PATCH 04/14] Simplify walko --- kanren/graph.py | 47 +++++++++++++++++------------------------------ 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/kanren/graph.py b/kanren/graph.py index 2a8a21e..4802289 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -197,38 +197,25 @@ def walko( The map relation used to apply `goal` to a sub-graph. """ - def walko_goal(s): + rator_in, rands_in, rator_out, rands_out = var(), var(), var(), var() - nonlocal goal, rator_goal, graph_in, graph_out, null_type, map_rel - - graph_in_rf, graph_out_rf = reify((graph_in, graph_out), s) - - rator_in, rands_in, rator_out, rands_out = var(), var(), var(), var() - - _walko = partial( - walko, goal, rator_goal=rator_goal, null_type=null_type, map_rel=map_rel - ) - - g = conde( - # TODO: Use `Zzz`, if needed. - [ - goal(graph_in_rf, graph_out_rf), - ], - [ - lall( - applyo(rator_in, rands_in, graph_in_rf), - applyo(rator_out, rands_out, graph_out_rf), - rator_goal(rator_in, rator_out), - map_rel(_walko, rands_in, rands_out, null_type=null_type), - ) - if rator_goal is not None - else map_rel(_walko, graph_in_rf, graph_out_rf, null_type=null_type), - ], - ) - - yield from g(s) + _walko = partial( + walko, goal, rator_goal=rator_goal, null_type=null_type, map_rel=map_rel + ) - return walko_goal + return conde( + [ + Zzz(goal, graph_in, graph_out), + ], + [ + applyo(rator_in, rands_in, graph_in), + applyo(rator_out, rands_out, graph_out), + Zzz(rator_goal, rator_in, rator_out), + Zzz(map_rel, _walko, rands_in, rands_out, null_type=null_type), + ] + if rator_goal is not None + else [Zzz(map_rel, _walko, graph_in, graph_out, null_type=null_type)], + ) def term_walko( From 1cffe75795e8b3fa6410febaf615d20773633af0 Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Sun, 14 Nov 2021 13:12:44 -0600 Subject: [PATCH 05/14] Introduce abstract term type --- kanren/term.py | 20 ++++++++++++++++++++ tests/test_term.py | 12 +++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/kanren/term.py b/kanren/term.py index 6c6a094..6f0f210 100644 --- a/kanren/term.py +++ b/kanren/term.py @@ -1,3 +1,4 @@ +from abc import ABCMeta from collections.abc import Mapping, Sequence from cons.core import ConsError, cons @@ -11,6 +12,25 @@ from .goals import conso +class TermMetaType(ABCMeta): + """A meta type that can be used to check for `operator`/`arguments` support.""" + + def __instancecheck__(self, o): + o_type = type(o) + if any(issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object): + return True + return False + + def __subclasscheck__(self, o_type): + if any(issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object): + return True + return False + + +class TermType(metaclass=TermMetaType): + pass + + def applyo(o_rator, o_rands, obj): """Construct a goal that relates an object to the application of its (ope)rator to its (ope)rands. diff --git a/tests/test_term.py b/tests/test_term.py index 8c79203..05faaf7 100644 --- a/tests/test_term.py +++ b/tests/test_term.py @@ -1,9 +1,10 @@ from cons import cons +from cons.core import ConsType from etuples import etuple from unification import reify, unify, var from kanren.core import run -from kanren.term import applyo, arguments, operator, term +from kanren.term import TermType, applyo, arguments, operator, term from tests.utils import Add, Node, Operator @@ -77,3 +78,12 @@ def __call__(self, *args): assert s == {x: 2} assert reify(NewNode(NewAdd, (1, x)), s) == NewNode(NewAdd, (1, 2)) + + +def test_TermType(): + assert issubclass(type(Add(1, 2)), TermType) + assert isinstance(Add(1, 2), TermType) + assert not issubclass(type([1, 2]), TermType) + assert not isinstance([1, 2], TermType) + assert not isinstance(ConsType, TermType) + assert not issubclass(type(ConsType), TermType) From 1faae2e2476b88d2eec6698835afddbe0f33dd27 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:46 -0600 Subject: [PATCH 06/14] Introduce a shallow ground ordering generic function This comes along with a warning/clarification to the assoccomm relations. They have always walked the first argument, which, when used with the old `ground_order` would only account for improperly ordered terms when the second argument was fully ground (and the first wasn't). However, it would perform this check repeatedly, incurring a full-traversal cost on every iteration. This inefficiency has been rectified. The new "shallow" ground order should accomplish the same--and more--by reordering according to "groundness" only at next level of traversal. The assoccomm goals need to be updated so that they properly use this shallow ordering, though. --- kanren/assoccomm.py | 35 ++++++++++----------------- kanren/core.py | 38 +++++++++++++++++------------ kanren/graph.py | 10 +++----- kanren/term.py | 7 +++++- tests/test_assoccomm.py | 53 +++++++++++++++++++++-------------------- tests/test_core.py | 18 ++++++++++++++ tests/test_term.py | 15 +++++++++++- 7 files changed, 103 insertions(+), 73 deletions(-) diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index 99edc58..fdd7c75 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -4,29 +4,6 @@ accomplishes this through naively trying all possibilities. 
This was built to be used in the computer algebra systems SymPy and Theano. ->>> from kanren import run, var, fact ->>> from kanren.assoccomm import eq_assoccomm as eq ->>> from kanren.assoccomm import commutative, associative - ->>> # Define some dummy Ops ->>> add = 'add' ->>> mul = 'mul' - ->>> # Declare that these ops are commutative using the facts system ->>> fact(commutative, mul) ->>> fact(commutative, add) ->>> fact(associative, mul) ->>> fact(associative, add) - ->>> # Define some wild variables ->>> x, y = var('x'), var('y') - ->>> # Two expressions to match ->>> pattern = (mul, (add, 1, x), y) # (1 + x) * y ->>> expr = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) - ->>> print(run(0, (x,y), eq(pattern, expr))) -((3, 2),) """ from collections.abc import Sequence from functools import partial @@ -176,6 +153,10 @@ def eq_assoc_args_goal(S): def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple): """Create a goal for associative unification of two terms. + Warning: This goal walks the left-hand argument, `u`, so make that argument + the most ground term; otherwise, it may iterate indefinitely when it should + actually terminate. + >>> from kanren import run, var, fact >>> from kanren.assoccomm import eq_assoc as eq @@ -196,6 +177,10 @@ def assoc_args_unique(a, b, op, **kwargs): def eq_comm(u, v, op_predicate=commutative, null_type=etuple): """Create a goal for commutative equality. + Warning: This goal walks the left-hand argument, `u`, so make that argument + the most ground term; otherwise, it may iterate indefinitely when it should + actually terminate. + >>> from kanren import run, var, fact >>> from kanren.assoccomm import eq_comm as eq >>> from kanren.assoccomm import commutative, associative @@ -239,6 +224,10 @@ def op_pred(sub_op): def eq_assoccomm(u, v, null_type=etuple): """Construct a goal for associative and commutative unification. + Warning: This goal walks the left-hand argument, `u`, so make that argument + the most ground term; otherwise, it may iterate indefinitely when it should + actually terminate. 
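To make the ordering caveat concrete, here is a sketch (not part of the patch itself) that mirrors the reordered algebra test later in this series; the fully ground expression is passed as `u` and the pattern containing logic variables as `v`:

    from unification import var
    from kanren.core import run
    from kanren.facts import fact
    from kanren.assoccomm import associative, commutative, eq_assoccomm

    add, mul = "add", "mul"
    for op in (add, mul):
        fact(commutative, op)
        fact(associative, op)

    x, y = var(), var()
    expr = (mul, 2, (add, 3, 1))     # fully ground: pass it first
    pattern = (mul, (add, 1, x), y)  # contains logic variables

    # The corresponding test in this series expects ((3, 2),) here.
    run(0, (x, y), eq_assoccomm(expr, pattern))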
+ >>> from kanren.assoccomm import eq_assoccomm as eq >>> from kanren.assoccomm import commutative, associative >>> from kanren import fact, run, var diff --git a/kanren/core.py b/kanren/core.py index 3da6b60..e22e159 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -1,12 +1,13 @@ -from collections.abc import Generator, Sequence +from collections.abc import Generator, Mapping, Sequence from functools import partial, reduce from itertools import tee from operator import length_hint +from statistics import mean -from cons.core import ConsPair +from cons.core import ConsPair, car, cdr +from multipledispatch import dispatch from toolz import interleave, take from unification import isvar, reify, unify -from unification.core import isground def fail(s): @@ -113,18 +114,25 @@ def conde(*goals): lany = ldisj -def ground_order_key(S, x): +@dispatch(Mapping, object) +def shallow_ground_order_key(S, x): if isvar(x): - return 2 - elif isground(x, S): - return -1 - elif issubclass(type(x), ConsPair): - return 1 - else: - return 0 - - -def ground_order(in_args, out_args): + return 10 + elif isinstance(x, ConsPair): + val = 0 + val += 1 if isvar(car(x)) else 0 + cdr_x = cdr(x) + if issubclass(type(x), ConsPair): + val += 2 if isvar(cdr_x) else 0 + elif len(cdr_x) == 1: + val += 1 if isvar(cdr_x[0]) else 0 + elif len(cdr_x) > 1: + val += mean(1.0 if isvar(i) else 0.0 for i in cdr_x) + return val + return 0 + + +def ground_order(in_args, out_args, key_fn=shallow_ground_order_key): """Construct a non-relational goal that orders a list of terms based on groundedness (grounded precede ungrounded).""" # noqa: E501 def ground_order_goal(S): @@ -134,7 +142,7 @@ def ground_order_goal(S): S_new = unify( list(out_args_rf) if isinstance(out_args_rf, Sequence) else out_args_rf, - sorted(in_args_rf, key=partial(ground_order_key, S)), + sorted(in_args_rf, key=partial(key_fn, S)), S, ) diff --git a/kanren/graph.py b/kanren/graph.py index 4802289..eec5ea3 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -3,7 +3,7 @@ from etuples import etuple from unification import isvar, reify, var -from .core import Zzz, conde, eq, fail, ground_order, lall, succeed +from .core import Zzz, conde, eq, fail, lall, succeed from .goals import conso, nullo from .term import applyo @@ -239,13 +239,11 @@ def term_walko( should always fail! 
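As a point of reference, the `assoccomm` goals are built on this relation; `eq_comm`, for instance, reduces to `term_walko(op_predicate, permuteo_unique, u, v)`. The following is only a simplified sketch of that pattern, not the actual implementation:

    from kanren.assoccomm import commutative
    from kanren.goals import permuteo
    from kanren.graph import term_walko

    def eq_comm_sketch(u, v):
        # Succeed when the (ope)rator is a known commutative operator.
        def comm_op(op):
            return commutative(op)

        # Relate the two argument lists by permutation; the real `eq_comm`
        # uses a de-duplicating wrapper (`permuteo_unique`) instead.
        def permuted_rands(u_rands, v_rands, op, **kwargs):
            return permuteo(u_rands, v_rands)

        return term_walko(comm_op, permuted_rands, u, v)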
""" - def single_step(s, t): - u, v = var(), var() + def single_step(u, v): u_rator, u_rands = var(), var() v_rands = var() return lall( - ground_order((s, t), (u, v)), applyo(u_rator, u_rands, u), applyo(u_rator, v_rands, v), rator_goal(u_rator), @@ -256,13 +254,11 @@ def single_step(s, t): Zzz(rands_goal, u_rands, v_rands, u_rator, **kwargs), ) - def term_walko_step(s, t): + def term_walko_step(u, v): nonlocal rator_goal, rands_goal, null_type - u, v = var(), var() z, w = var(), var() return lall( - ground_order((s, t), (u, v)), format_step(u, w) if format_step is not None else eq(u, w), conde( [ diff --git a/kanren/term.py b/kanren/term.py index 6f0f210..d7fedfb 100644 --- a/kanren/term.py +++ b/kanren/term.py @@ -8,7 +8,7 @@ from unification.core import _reify, _unify, construction_sentinel, reify from unification.variable import isvar -from .core import eq, lall +from .core import eq, lall, shallow_ground_order_key from .goals import conso @@ -104,3 +104,8 @@ def unify_term(u, v, s): if s is not False: s = yield _unify(u_args, v_args, s) yield s + + +@shallow_ground_order_key.register(Mapping, TermType) +def shallow_ground_order_key_TermType(S, x): + return shallow_ground_order_key(S, cons(operator(x), arguments(x))) diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index 90dddd1..372e1a6 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -39,8 +39,8 @@ def test_eq_comm(): assert run(0, True, eq_comm((comm_op, 1, 2, 3), (comm_op, 1, 2, 3))) == (True,) assert run(0, True, eq_comm((comm_op, 3, 2, 1), (comm_op, 1, 2, 3))) == (True,) - assert run(0, y, eq_comm((comm_op, 3, y, 1), (comm_op, 1, 2, 3))) == (2,) - assert run(0, (x, y), eq_comm((comm_op, x, y, 1), (comm_op, 1, 2, 3))) == ( + assert run(0, y, eq_comm((comm_op, 1, 2, 3), (comm_op, 3, y, 1))) == (2,) + assert run(0, (x, y), eq_comm((comm_op, 1, 2, 3), (comm_op, x, y, 1))) == ( (2, 3), (3, 2), ) @@ -86,9 +86,9 @@ def test_eq_comm(): assert expected_res == set( run(0, (x, y, z), eq_comm((comm_op, 1, 2, 3), (comm_op, x, y, z))) ) - assert expected_res == set( - run(0, (x, y, z), eq_comm((comm_op, x, y, z), (comm_op, 1, 2, 3))) - ) + # assert expected_res == set( + # run(0, (x, y, z), eq_comm((comm_op, x, y, z), (comm_op, 1, 2, 3))) + # ) assert expected_res == set( run( 0, @@ -97,23 +97,20 @@ def test_eq_comm(): ) ) - e1 = (comm_op, (comm_op, 1, x), y) - e2 = (comm_op, 2, (comm_op, 3, 1)) + e1 = (comm_op, 2, (comm_op, 3, 1)) + e2 = (comm_op, (comm_op, 1, x), y) assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) e1 = ((comm_op, 3, 1),) e2 = ((comm_op, 1, x),) - assert run(0, x, eq_comm(e1, e2)) == (3,) e1 = (2, (comm_op, 3, 1)) e2 = (y, (comm_op, 1, x)) - assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) - e1 = (comm_op, (comm_op, 1, x), y) - e2 = (comm_op, 2, (comm_op, 3, 1)) - + e1 = (comm_op, 2, (comm_op, 3, 1)) + e2 = (comm_op, (comm_op, 1, x), y) assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) @@ -296,12 +293,12 @@ def test_eq_assoc(): (assoc_op, 1, (assoc_op, 2, 3)), ) - res = run(0, x, eq_assoc(x, (assoc_op, 1, 2, 3), n=2)) - assert res == ( - (assoc_op, (assoc_op, 1, 2), 3), - (assoc_op, 1, 2, 3), - (assoc_op, 1, (assoc_op, 2, 3)), - ) + # res = run(0, x, eq_assoc(x, (assoc_op, 1, 2, 3), n=2)) + # assert res == ( + # (assoc_op, (assoc_op, 1, 2), 3), + # (assoc_op, 1, 2, 3), + # (assoc_op, 1, (assoc_op, 2, 3)), + # ) y, z = var(), var() @@ -322,7 +319,7 @@ def test_eq_assoc(): assert all(isvar(i) for i in reify((x, y, z), s)) # Make sure it works with `cons` - res = run(0, (x, y), 
eq_assoc(cons(x, y), (assoc_op, 1, 2, 3))) + res = run(0, (x, y), eq_assoc((assoc_op, 1, 2, 3), cons(x, y))) assert res == ( (assoc_op, ((assoc_op, 1, 2), 3)), (assoc_op, (1, 2, 3)), @@ -337,8 +334,8 @@ def test_eq_assoc(): # run(1, (x, y), eq_assoc(cons(x, y), (x, z), op_predicate=associative_2)) # Nested expressions should work now - expr1 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) - expr2 = (assoc_op, (assoc_op, 1, 2), 3, 4, 5, 6) + expr1 = (assoc_op, (assoc_op, 1, 2), 3, 4, 5, 6) + expr2 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) assert run(0, x, eq_assoc(expr1, expr2, n=2)) == ((assoc_op, 3, 4),) @@ -402,7 +399,7 @@ def test_eq_assoccomm(): assert run(0, True, eq_assoccomm(1, 1)) == (True,) assert run(0, True, eq_assoccomm((1,), (1,))) == (True,) - assert run(0, True, eq_assoccomm(x, (1,))) == (True,) + # assert run(0, True, eq_assoccomm(x, (1,))) == (True,) assert run(0, True, eq_assoccomm((1,), x)) == (True,) # Assoc only @@ -444,12 +441,16 @@ def test_eq_assoccomm(): assert set(run(0, x, eq_assoccomm((ac, 1, 3, 2), x))) == exp_res assert set(run(0, x, eq_assoccomm((ac, 2, (ac, 3, 1)), x))) == exp_res # LHS variations - assert set(run(0, x, eq_assoccomm(x, (ac, 1, (ac, 2, 3))))) == exp_res + # assert set(run(0, x, eq_assoccomm(x, (ac, 1, (ac, 2, 3))))) == exp_res - assert run(0, (x, y), eq_assoccomm((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) == ( + assert run(0, (x, y), eq_assoccomm((ac, 2, (ac, 3, 1)), (ac, (ac, 1, x), y))) == ( (2, 3), (3, 2), ) + # assert run(0, (x, y), eq_assoccomm((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) == ( + # (2, 3), + # (3, 2), + # ) assert run(0, True, eq_assoccomm((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) == (True,) assert run(0, True, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) == (True,) @@ -502,8 +503,8 @@ def test_assoccomm_algebra(): x, y = var(), var() - pattern = (mul, (add, 1, x), y) # (1 + x) * y - expr = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) + pattern = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) + expr = (mul, (add, 1, x), y) # (1 + x) * y assert run(0, (x, y), eq_assoccomm(pattern, expr)) == ((3, 2),) diff --git a/tests/test_core.py b/tests/test_core.py index b1a761e..7b98a03 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -18,6 +18,7 @@ ldisj, ldisj_seq, run, + shallow_ground_order_key, succeed, ) @@ -248,11 +249,28 @@ def test_ifa(): def test_ground_order(): x, y, z = var(), var(), var() + + assert shallow_ground_order_key({}, x) > shallow_ground_order_key({}, (x, y)) + assert shallow_ground_order_key({}, cons(x, y)) > shallow_ground_order_key( + {}, (x, y) + ) + assert shallow_ground_order_key({}, cons(1, 2)) < shallow_ground_order_key( + {}, (1, 2, 3, 4, y) + ) + assert shallow_ground_order_key({}, cons(1, 2)) == shallow_ground_order_key( + {}, (1, 2, 3, 4) + ) + assert shallow_ground_order_key({}, (x, y)) == shallow_ground_order_key( + {}, (x, y, z) + ) + assert run(0, x, ground_order((y, [1, z], 1), x)) == ([1, [1, z], y],) + a, b, c = var(), var(), var() assert run(0, (a, b, c), ground_order((y, [1, z], 1), (a, b, c))) == ( (1, [1, z], y), ) + res = run(0, z, ground_order([cons(x, y), (x, y)], z)) assert res == ([(x, y), cons(x, y)],) res = run(0, z, ground_order([(x, y), cons(x, y)], z)) diff --git a/tests/test_term.py b/tests/test_term.py index 05faaf7..9c8a1c4 100644 --- a/tests/test_term.py +++ b/tests/test_term.py @@ -3,7 +3,7 @@ from etuples import etuple from unification import reify, unify, var -from kanren.core import run +from kanren.core import run, shallow_ground_order_key from kanren.term import TermType, applyo, arguments, 
operator, term from tests.utils import Add, Node, Operator @@ -87,3 +87,16 @@ def test_TermType(): assert not isinstance([1, 2], TermType) assert not isinstance(ConsType, TermType) assert not issubclass(type(ConsType), TermType) + + +def test_shallow_ground_order(): + + x, y, z = var(), var(), var() + + assert shallow_ground_order_key({}, x) > shallow_ground_order_key({}, Add(x, y)) + assert shallow_ground_order_key({}, cons(x, y)) > shallow_ground_order_key( + {}, Add(x, y) + ) + assert shallow_ground_order_key({}, Add(x, y)) == shallow_ground_order_key( + {}, Add(x, y, z) + ) From 5a848b22f85b73e176920eb89d9049578987622b Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:47 -0600 Subject: [PATCH 07/14] Add a groundedness ordering for use with two sequences --- kanren/core.py | 38 ++++++++++++++++++++++++++++++++++++++ tests/test_core.py | 23 +++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/kanren/core.py b/kanren/core.py index e22e159..1d60118 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -152,6 +152,44 @@ def ground_order_goal(S): return ground_order_goal +def ground_order_seqs(in_seqs, out_seqs, key_fn=shallow_ground_order_key): + """Construct a non-relational goal that orders lists of sequences based on the groundedness of their corresponding terms. # noqa: E501 + + >>> from unification import var + >>> x, y = var('x'), var('y') + >>> a, b = var('a'), var('b') + >>> run(0, (x, y), ground_order_seqs([(a, b), (b, 2)], [x, y])) + (((~b, ~a), (2, ~b)),) + """ + + def ground_order_seqs_goal(S): + nonlocal in_seqs, out_seqs, key_fn + + in_seqs_rf, out_seqs_rf = reify((in_seqs, out_seqs), S) + + if ( + not any(isinstance(s, str) for s in in_seqs_rf) + and reduce( + lambda x, y: x == y and y, (length_hint(s, -1) for s in in_seqs_rf) + ) + > 0 + ): + + in_seqs_ord = zip(*sorted(zip(*in_seqs_rf), key=partial(key_fn, S))) + S_new = unify(list(out_seqs_rf), list(in_seqs_ord), S) + + if S_new is not False: + yield S_new + else: + + S_new = unify(out_seqs_rf, in_seqs_rf, S) + + if S_new is not False: + yield S_new + + return ground_order_seqs_goal + + def ifa(g1, g2): """Create a goal operator that returns the first stream unless it fails.""" diff --git a/tests/test_core.py b/tests/test_core.py index 7b98a03..8444c44 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -10,6 +10,7 @@ eq, fail, ground_order, + ground_order_seqs, ifa, lall, lany, @@ -275,3 +276,25 @@ def test_ground_order(): assert res == ([(x, y), cons(x, y)],) res = run(0, z, ground_order([(x, y), cons(x, y)], z)) assert res == ([(x, y), cons(x, y)],) + + +def test_ground_order_seq(): + + x, y, z = var(), var(), var() + a, b = var(), var() + res = run(0, (x, y), ground_order_seqs([a, (b, 2)], [x, y])) + assert res == ((a, (b, 2)),) + res = run(0, (x, y), ground_order_seqs([(a,), (b, 2)], [x, y])) + assert res == (((a,), (b, 2)),) + res = run(0, (x, y), ground_order_seqs([(a, b), (b, 2)], [x, y])) + assert res == (((b, a), (2, b)),) + res = run(0, (x, y), ground_order_seqs([(b, 2), (a, b)], [x, y])) + assert res == (((2, b), (b, a)),) + res = run(0, (x, y, z), ground_order_seqs([(b, 2), (a, b), (0, 1)], [x, y, z])) + assert res == (((2, b), (b, a), (1, 0)),) + res = run(0, (x, y), ground_order_seqs([(), ()], [x, y])) + assert res == (((), ()),) + res = run(0, (x, y), ground_order_seqs([(a, (1, b)), (b, 2)], [x, y])) + assert res == ((((1, b), a), (2, b)),) + res = run(0, (x, y), ground_order_seqs(["abc", "def"], [x, y])) + assert res == (("abc", "def"),) From 
ade82365e8162d565416f04ad39d53a187aa5a6e Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:48 -0600 Subject: [PATCH 08/14] Use shallow associative flattening --- kanren/assoccomm.py | 7 +++++-- tests/test_assoccomm.py | 24 +++++++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index fdd7c75..1e34baf 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -25,12 +25,15 @@ commutative = Relation("commutative") -def flatten_assoc_args(op_predicate, items): +def flatten_assoc_args(op_predicate, items, shallow=True): for i in items: if isinstance(i, ConsPair) and op_predicate(car(i)): i_cdr = cdr(i) if length_hint(i_cdr) > 0: - yield from flatten_assoc_args(op_predicate, i_cdr) + if shallow: + yield from iter(i_cdr) + else: + yield from flatten_assoc_args(op_predicate, i_cdr) else: yield i else: diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index 372e1a6..50c9a2f 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -140,12 +140,34 @@ def op_pred(x): res = list( flatten_assoc_args( - op_pred, [[1, 2, op], 3, [op, 4, [op, [op]]], [op, 5], 6, op, 7] + op_pred, + [[1, 2, op], 3, [op, 4, [op, [op]]], [op, 5], 6, op, 7], + shallow=False, ) ) exp_res = [[1, 2, op], 3, 4, [op], 5, 6, op, 7] assert res == exp_res + exa_col = (1, 2, ("b", 3, ("a", 4, 5)), ("c", 6, 7), ("a", ("a", 8, 9), 10)) + assert list(flatten_assoc_args(lambda x: x == "a", exa_col, shallow=False)) == [ + 1, + 2, + ("b", 3, ("a", 4, 5)), + ("c", 6, 7), + 8, + 9, + 10, + ] + + assert list(flatten_assoc_args(lambda x: x == "a", exa_col, shallow=True)) == [ + 1, + 2, + ("b", 3, ("a", 4, 5)), + ("c", 6, 7), + ("a", 8, 9), + 10, + ] + def test_assoc_args(): op = "add" From 777e25599e88d8e74f35e7beca110f793c3c2cd2 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:49 -0600 Subject: [PATCH 09/14] Improve flatten_assoc_args and make assoc_flatten relational --- kanren/assoccomm.py | 250 +++++++++++++++------ kanren/term.py | 34 ++- requirements.txt | 1 + setup.py | 2 +- tests/test_assoccomm.py | 477 ++++++++++++++++++++++++++++++++-------- tests/test_term.py | 6 +- tests/utils.py | 28 +++ 7 files changed, 629 insertions(+), 169 deletions(-) diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index 1e34baf..a877203 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -10,60 +10,205 @@ from operator import eq as equal from operator import length_hint -from cons.core import ConsPair, car, cdr from etuples import etuple -from toolz import sliding_window from unification import reify, unify, var -from .core import conde, eq, ground_order, lall, succeed +from .core import Zzz, conde, eq, fail, ground_order, lall, lany, succeed from .facts import Relation -from .goals import itero, permuteo +from .goals import conso, itero, permuteo from .graph import term_walko -from .term import term +from .term import TermType, applyo, arguments, operator, term associative = Relation("associative") commutative = Relation("commutative") -def flatten_assoc_args(op_predicate, items, shallow=True): - for i in items: - if isinstance(i, ConsPair) and op_predicate(car(i)): - i_cdr = cdr(i) - if length_hint(i_cdr) > 0: - if shallow: - yield from iter(i_cdr) +def flatten_assoc_args(op_predicate, term, shallow=True): + """Flatten/normalize a term with an associative operator. + + Parameters + ---------- + op_predicate: callable + A function used to determine the operators to flatten. 
+ items: Sequence + The term to flatten. + shallow: bool (optional) + Indicate whether or not flattening should be done at all depths. + """ + + if not isinstance(term, TermType): + return term + + def _flatten(term): + for i in term: + if isinstance(i, TermType) and op_predicate(operator(i)): + i_cdr = arguments(i) + if length_hint(i_cdr) > 0: + if shallow: + yield from iter(i_cdr) + else: + yield from _flatten(i_cdr) else: - yield from flatten_assoc_args(op_predicate, i_cdr) + yield i else: yield i - else: - yield i + term_type = type(arguments(term)) + return term_type(_flatten(term)) -def assoc_args(rator, rands, n, ctor=None): - """Produce all associative argument combinations of rator + rands in n-sized rand groupings. - >>> from kanren.assoccomm import assoc_args - >>> list(assoc_args('op', [1, 2, 3], 2)) - [[['op', 1, 2], 3], [1, ['op', 2, 3]]] - """ # noqa: E501 - assert n > 0 +def partitions(in_seq, n_parts=None, min_size=1, part_fn=lambda x: x): + """Generate all partitions of a sequence for given numbers of partitions and minimum group sizes. # noqa: E501 + + Parameters + ---------- + in_seq: Sequence + The sequence to be partitioned. + n_parts: int + Number of partitions. `None` means all partitions in `range(2, len(in_seq))`. + min_size: int + The minimum size of a partition. + part_fn: Callable + A function applied to every partition. + """ + + def partition(seq, res): + if ( + n_parts is None + and + # We don't want the original sequence + len(res) > 0 + ) or len(res) + 1 == n_parts: + yield type(in_seq)(res + [part_fn(seq)]) + + if n_parts is not None: + return + + for s in range(min_size, len(seq) + 1 - min_size, 1): + yield from partition(seq[s:], res + [part_fn(seq[:s])]) + + return partition(in_seq, []) + + +def assoc_args(rator, rands, n=None, ctor=None): + """Produce all associative argument combinations of rator + rands in n-sized rand groupings. # noqa: E501 + + The normal/canonical form is left-associative, e.g. + `(op, 1, 2, 3, 4) == (op, (op, (op, 1, 2), 3), 4)` - rands_l = list(rands) + Parameters + ---------- + rator: object + The operator that's assumed to be associative. + rands: Sequence + The operands. + n: int (optional) + The number of arguments in the resulting `(op,) + output` terms. + If not specified, all argument sizes are returned. + ctor: callable + The constructor to use when each term is created. + If not specified, the constructor is inferred from `type(rands)`. + """ if ctor is None: ctor = type(rands) - if n == len(rands_l): + if len(rands) <= 2 or n is not None and len(rands) <= n: yield ctor(rands) return - for i, new_rands in enumerate(sliding_window(n, rands_l)): - prefix = rands_l[:i] - new_term = term(rator, ctor(new_rands)) - suffix = rands_l[n + i :] - res = ctor(prefix + [new_term] + suffix) - yield res + def part_fn(x): + if len(x) == 1: + return x[0] + else: + return term(rator, ctor(x)) + + for p in partitions(rands, n, 1, part_fn=part_fn): + yield ctor(p) + + +def assoc_flatteno(a, a_flat, no_ident=False, null_type=etuple): + """Construct a goal that flattens/normalizes terms with associative operators. + + The normal/canonical form is left-associative, e.g. + `(op, 1, 2, 3, 4) == (op, (op, (op, 1, 2), 3), 4)` + + Parameters + ---------- + a: Var or Sequence + The "input" term to flatten. + a_flat: Var or Sequence + The flattened result, or "output", term. + no_ident: bool (optional) + Whether or not to fail if no flattening occurs. 
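A small example (not taken from the patch) that mirrors the `assoc_flatteno` tests added later in this series; the string operator "add" is assumed to have been declared associative:

    from unification import var
    from kanren.core import run
    from kanren.facts import fact
    from kanren.assoccomm import assoc_flatteno, associative

    add = "add"
    fact(associative, add)
    x = var()

    # Forward: flatten nested `add` applications.
    run(0, x, assoc_flatteno((add, 1, (add, 2, 3), 4), x))
    # expected: ((add, 1, 2, 3, 4),)

    # Reverse: generate re-nested variations of an already flat term.
    run(0, x, assoc_flatteno(x, (add, 1, 2, 3), no_ident=True))
    # yields terms such as (add, 1, (add, 2, 3)) and (add, (add, 1, 2), 3)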
+ """ + + def assoc_flatteno_goal(S): + nonlocal a, a_flat + + a_rf, a_flat_rf = reify((a, a_flat), S) + + if isinstance(a_rf, TermType) and (operator(a_rf),) in associative.facts: + + a_op = operator(a_rf) + args_rf = arguments(a_rf) + + def op_pred(sub_op): + return sub_op == a_op + + a_flat_rf = term(a_op, flatten_assoc_args(op_pred, args_rf)) + + if a_flat_rf == a_rf and no_ident: + return + + yield from eq(a_flat, a_flat_rf)(S) + + elif ( + isinstance(a_flat_rf, TermType) + and (operator(a_flat_rf),) in associative.facts + ): + + a_flat_op = operator(a_flat_rf) + args_rf = arguments(a_flat_rf) + assoc_args_iter = assoc_args(a_flat_op, args_rf) + + # TODO: There are much better ways to do this `no_ident` check + # (e.g. the `n` argument of `assoc_args` should probably be made to + # work for this) + yield from lany( + fail if no_ident and r is args_rf else applyo(a_flat_op, r, a_rf) + for r in assoc_args_iter + )(S) + + else: + + op = var() + a_rands = var() + a_rands_rands = var() + a_flat_rands = var() + a_flat_rands_rands = var() + + g = conde( + [fail if no_ident else eq(a_rf, a_flat_rf)], + [ + associative(op), + applyo(op, a_rands, a_rf), + applyo(op, a_flat_rands, a_flat_rf), + # There must be at least two rands + conso(var(), a_rands_rands, a_rands), + conso(var(), var(), a_rands_rands), + conso(var(), a_flat_rands_rands, a_flat_rands), + conso(var(), var(), a_flat_rands_rands), + itero( + a_flat_rands, nullo_refs=(a_rands,), default_ConsNull=null_type + ), + Zzz(assoc_flatteno, a_rf, a_flat_rf, no_ident=no_ident), + ], + ) + + yield from g(S) + + return assoc_flatteno_goal def eq_assoc_args( @@ -98,15 +243,17 @@ def eq_assoc_args_goal(S): u_args_flat = type(u_args_rf)(flatten_assoc_args(op_pred, u_args_rf)) v_args_flat = type(v_args_rf)(flatten_assoc_args(op_pred, v_args_rf)) - if len(u_args_flat) == len(v_args_flat): + u_len, v_len = len(u_args_flat), len(v_args_flat) + if u_len == v_len: g = inner_eq(u_args_flat, v_args_flat) else: - if len(u_args_flat) < len(v_args_flat): + if u_len < v_len: sm_args, lg_args = u_args_flat, v_args_flat + grp_sizes = u_len else: sm_args, lg_args = v_args_flat, u_args_flat + grp_sizes = v_len - grp_sizes = len(lg_args) - len(sm_args) + 1 assoc_terms = assoc_args( op_rf, lg_args, grp_sizes, ctor=type(u_args_rf) ) @@ -129,20 +276,13 @@ def eq_assoc_args_goal(S): u_args_flat = list(flatten_assoc_args(partial(equal, op_rf), u_args_rf)) - if n_rf is not None: - arg_sizes = [n_rf] - else: - arg_sizes = range(2, len(u_args_flat) + (not no_ident)) - - v_ac_args = ( - v_ac_arg - for n_i in arg_sizes + g = conde( + [inner_eq(v_args_rf, v_ac_arg)] for v_ac_arg in assoc_args( - op_rf, u_args_flat, n_i, ctor=type(u_args_rf) + op_rf, u_args_flat, n_rf, ctor=type(u_args_rf) ) if not no_ident or v_ac_arg != u_args_rf ) - g = conde([inner_eq(v_args_rf, v_ac_arg)] for v_ac_arg in v_ac_args) yield from g(S) @@ -202,28 +342,6 @@ def permuteo_unique(x, y, op, **kwargs): return term_walko(op_predicate, permuteo_unique, u, v) -def assoc_flatten(a, a_flat): - def assoc_flatten_goal(S): - nonlocal a, a_flat - - a_rf = reify(a, S) - - if isinstance(a_rf, Sequence) and (a_rf[0],) in associative.facts: - - def op_pred(sub_op): - nonlocal S - sub_op_rf = reify(sub_op, S) - return sub_op_rf == a_rf[0] - - a_flat_rf = type(a_rf)(flatten_assoc_args(op_pred, a_rf)) - else: - a_flat_rf = a_rf - - yield from eq(a_flat, a_flat_rf)(S) - - return assoc_flatten_goal - - def eq_assoccomm(u, v, null_type=etuple): """Construct a goal for associative and commutative unification. 
@@ -267,6 +385,6 @@ def eq_assoccomm_step(a, b, op): eq_assoccomm_step, u, v, - format_step=assoc_flatten, + format_step=assoc_flatteno, no_ident=False, ) diff --git a/kanren/term.py b/kanren/term.py index d7fedfb..7cc0e3e 100644 --- a/kanren/term.py +++ b/kanren/term.py @@ -1,10 +1,13 @@ from abc import ABCMeta -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping +from operator import length_hint -from cons.core import ConsError, cons +from cons.core import ConsError, ProperSequence, cons +from etuples import apply from etuples import apply as term from etuples import rands as arguments from etuples import rator as operator +from multipledispatch import dispatch from unification.core import _reify, _unify, construction_sentinel, reify from unification.variable import isvar @@ -17,12 +20,19 @@ class TermMetaType(ABCMeta): def __instancecheck__(self, o): o_type = type(o) - if any(issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object): + if ( + isinstance(o, ProperSequence) + or any( + issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object + ) + ) and length_hint(o, 1) > 0: return True return False def __subclasscheck__(self, o_type): - if any(issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object): + if issubclass(o_type, ProperSequence) or any( + issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object + ): return True return False @@ -75,14 +85,19 @@ def applyo_goal(S): return applyo_goal -@term.register(object, Sequence) -def term_Sequence(rator, rands): +@dispatch(object, ProperSequence) +def term(rator, rands): # Overwrite the default `apply` dispatch function and make it preserve # types res = cons(rator, rands) return res +@term.register(Callable, ProperSequence) +def term_ExpressionTuple(rator, rands): + return apply(rator, rands) + + def unifiable_with_term(cls): _reify.add((cls, Mapping), reify_term) _unify.add((cls, cls, Mapping), unify_term) @@ -108,4 +123,9 @@ def unify_term(u, v, s): @shallow_ground_order_key.register(Mapping, TermType) def shallow_ground_order_key_TermType(S, x): - return shallow_ground_order_key(S, cons(operator(x), arguments(x))) + if length_hint(x, 1) > 0: + return shallow_ground_order_key.dispatch(type(S), object)( + S, cons(operator(x), arguments(x)) + ) + else: + return shallow_ground_order_key.dispatch(type(S), object)(S, x) diff --git a/requirements.txt b/requirements.txt index 87db56e..995ffd8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pydocstyle>=3.0.0 pytest>=5.0.0 pytest-cov>=2.6.1 pytest-html>=1.20.0 +pytest-timeout pylint>=2.3.1 black>=19.3b0; platform.python_implementation!='PyPy' diff-cover diff --git a/setup.py b/setup.py index c7a77c6..06ee62d 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ packages=["kanren"], install_requires=[ "toolz", - "cons >= 0.4.0", + "cons >= 0.4.2", "multipledispatch", "etuples >= 0.3.1", "logical-unification >= 0.4.1", diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index 50c9a2f..b0fb644 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -1,3 +1,5 @@ +from itertools import chain + import pytest from cons import cons from etuples.core import etuple @@ -5,7 +7,7 @@ from kanren.assoccomm import ( assoc_args, - assoc_flatten, + assoc_flatteno, associative, commutative, eq_assoc, @@ -13,6 +15,7 @@ eq_assoccomm, eq_comm, flatten_assoc_args, + partitions, ) from kanren.core import run from kanren.facts import fact @@ -169,32 +172,116 @@ def op_pred(x): 
] -def test_assoc_args(): - op = "add" +def test_partitions(): - def op_pred(x): - return x == op + assert list(partitions(("a",), 2, 2)) == [] + assert list(partitions(("a", "b"), 2, 2)) == [] + assert list(partitions(("a", "b"), 2, 1)) == [(("a",), ("b",))] + assert list(partitions(("a", "b"), 2, 1, part_fn=lambda x: ("op",) + x)) == [ + (("op", "a"), ("op", "b")) + ] - assert tuple(assoc_args(op, (1, 2, 3), 2)) == ( - ((op, 1, 2), 3), - (1, (op, 2, 3)), - ) - assert tuple(assoc_args(op, [1, 2, 3], 2)) == ( - [[op, 1, 2], 3], - [1, [op, 2, 3]], + exa_col = tuple("abcdefg") + + expected_res = [ + (("a", "b"), ("c", "d", "e", "f", "g")), + (("a", "b", "c"), ("d", "e", "f", "g")), + (("a", "b", "c", "d"), ("e", "f", "g")), + (("a", "b", "c", "d", "e"), ("f", "g")), + ] + + assert list(partitions(exa_col, 2, 2)) == expected_res + + expected_res = [ + (("a", "b"), ("c", "d"), ("e", "f", "g")), + (("a", "b"), ("c", "d", "e"), ("f", "g")), + (("a", "b", "c"), ("d", "e"), ("f", "g")), + ] + + assert list(partitions(exa_col, 3, 2)) == expected_res + + expected_res = sorted( + chain.from_iterable( + [partitions(exa_col, i, 2) for i in range(2, len(exa_col) + 1)] + ) ) - assert tuple(assoc_args(op, (1, 2, 3), 1)) == ( - ((op, 1), 2, 3), - (1, (op, 2), 3), - (1, 2, (op, 3)), + assert sorted(partitions(exa_col, None, 2)) == expected_res + + res = list( + partitions( + tuple(range(1, 5)), + None, + 1, + part_fn=lambda x: x[0] if len(x) == 1 else ("op",) + x, + ) ) - assert tuple(assoc_args(op, (1, 2, 3), 3)) == ((1, 2, 3),) + assert res == [ + (1, ("op", 2, 3, 4)), + (1, 2, ("op", 3, 4)), + (1, 2, 3, 4), + (1, ("op", 2, 3), 4), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2), 3, 4), + (("op", 1, 2, 3), 4), + ] - f_rands = flatten_assoc_args(op_pred, (1, (op, 2, 3))) - assert tuple(assoc_args(op, f_rands, 2, ctor=tuple)) == ( - ((op, 1, 2), 3), - (1, (op, 2, 3)), + res = list( + partitions( + tuple(range(1, 5)), + 2, + 1, + part_fn=lambda x: x[0] if len(x) == 1 else ("op",) + x, + ) ) + assert res == [ + (1, ("op", 2, 3, 4)), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2, 3), 4), + ] + + +def test_assoc_args(): + + res = list(assoc_args("op", tuple(range(1, 5)), None)) + assert res == [ + (1, ("op", 2, 3, 4)), + (1, 2, ("op", 3, 4)), + (1, 2, 3, 4), + (1, ("op", 2, 3), 4), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2), 3, 4), + (("op", 1, 2, 3), 4), + ] + + res = list(assoc_args("op", tuple(range(1, 5)), None, ctor=list)) + assert res == [ + [1, ["op", 2, 3, 4]], + [1, 2, ["op", 3, 4]], + [1, 2, 3, 4], + [1, ["op", 2, 3], 4], + [["op", 1, 2], ["op", 3, 4]], + [["op", 1, 2], 3, 4], + [["op", 1, 2, 3], 4], + ] + + res = list(assoc_args("op", tuple(range(1, 5)), 2)) + assert res == [ + (1, ("op", 2, 3, 4)), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2, 3), 4), + ] + + res = list(assoc_args("op", (1, 2), 1)) + assert res == [(1, 2)] + + res = list(assoc_args("op", (1, 2, 3), 4)) + assert res == [(1, 2, 3)] + + res = list(assoc_args("op", (1, 2, 3), 3)) + assert res == [(1, 2, 3)] + + res = list(assoc_args("op", [1, 2, 3], 3, ctor=tuple)) + assert res == [(1, 2, 3)] def test_eq_assoc_args(): @@ -237,18 +324,18 @@ def test_eq_assoc_args(): assert run(0, True, eq_assoc_args(assoc_op, (1, 1), ("other_op", 1, 1))) == () assert run(0, x, eq_assoc_args(assoc_op, (1, 2, 3), x, n=2)) == ( - ((assoc_op, 1, 2), 3), (1, (assoc_op, 2, 3)), + ((assoc_op, 1, 2), 3), ) assert run(0, x, eq_assoc_args(assoc_op, x, (1, 2, 3), n=2)) == ( - ((assoc_op, 1, 2), 3), (1, (assoc_op, 2, 3)), + ((assoc_op, 1, 2), 3), ) assert 
run(0, x, eq_assoc_args(assoc_op, (1, 2, 3), x)) == ( - ((assoc_op, 1, 2), 3), (1, (assoc_op, 2, 3)), (1, 2, 3), + ((assoc_op, 1, 2), 3), ) assert () not in run(0, x, eq_assoc_args(assoc_op, (), x, no_ident=True)) @@ -288,11 +375,11 @@ def test_eq_assoc_args(): def test_eq_assoc(): - assoc_op = "assoc_op" - associative.index.clear() associative.facts.clear() + assoc_op = "assoc_op" + fact(associative, assoc_op) assert run(0, True, eq_assoc(1, 1)) == (True,) @@ -307,14 +394,36 @@ def test_eq_assoc(): o = "op" assert not run(0, True, eq_assoc((o, 1, 2, 3), (o, (o, 1, 2), 3))) - x = var() + x, y = var(), var() + res = run(0, x, eq_assoc((assoc_op, 1, 2, 3), x, n=2)) assert res == ( - (assoc_op, (assoc_op, 1, 2), 3), - (assoc_op, 1, 2, 3), (assoc_op, 1, (assoc_op, 2, 3)), + (assoc_op, 1, 2, 3), + (assoc_op, (assoc_op, 1, 2), 3), + ) + + # Make sure it works with `cons` + res = run(0, (x, y), eq_assoc((assoc_op, 1, 2, 3), cons(x, y))) + assert sorted(res, key=str) == sorted( + [ + (assoc_op, (1, (assoc_op, 2, 3))), + (assoc_op, (1, 2, 3)), + (assoc_op, ((assoc_op, 1, 2), 3)), + ], + key=str, ) + # XXX: Don't use a predicate that can never succeed, e.g. + # associative_2 = Relation("associative_2") + # run(1, (x, y), eq_assoc(cons(x, y), (x, z), op_predicate=associative_2)) + + # Nested expressions should work now + expr1 = (assoc_op, (assoc_op, 1, 2), 3, 4, 5, 6) + expr2 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) + assert run(0, x, eq_assoc(expr1, expr2, n=2)) == ((assoc_op, 3, 4),) + + # TODO: Need groundedness ordering for this # res = run(0, x, eq_assoc(x, (assoc_op, 1, 2, 3), n=2)) # assert res == ( # (assoc_op, (assoc_op, 1, 2), 3), @@ -322,46 +431,88 @@ def test_eq_assoc(): # (assoc_op, 1, (assoc_op, 2, 3)), # ) - y, z = var(), var() + +@pytest.mark.xfail(strict=False) +def test_eq_assoc_cons(): + associative.index.clear() + associative.facts.clear() + + assoc_op = "assoc_op" + + fact(associative, assoc_op) + + x, y, z = var(), var(), var() + + res = run(1, (x, y), eq_assoc(cons(x, y), (x, z, 2, 3))) + assert res == ((assoc_op, (z, (assoc_op, 2, 3))),) + + +@pytest.mark.xfail(strict=False) +def test_eq_assoc_all_variations(): + + associative.index.clear() + associative.facts.clear() + + assoc_op = "assoc_op" + + fact(associative, assoc_op) + + x = var() + expected_res = { + # Normalized, our results are left-associative, i.e. 
+ # (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4), + # is equal to the following: + (assoc_op, 1, 2, 3, 4), + (assoc_op, (assoc_op, 1, 2), 3, 4), + (assoc_op, 1, (assoc_op, 2, 3), 4), + (assoc_op, 1, 2, (assoc_op, 3, 4)), + (assoc_op, (assoc_op, 1, 2, 3), 4), + (assoc_op, 1, (assoc_op, 2, 3, 4)), + (assoc_op, (assoc_op, 1, 2), (assoc_op, 3, 4)), + (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4), + (assoc_op, (assoc_op, 1, (assoc_op, 2, 3)), 4), + (assoc_op, 1, (assoc_op, (assoc_op, 2, 3), 4)), + (assoc_op, 1, (assoc_op, 2, (assoc_op, 3, 4))), + } + res = run(0, x, eq_assoc((assoc_op, 1, 2, 3, 4), x)) + assert sorted(res, key=str) == sorted(expected_res, key=str) + + +def test_eq_assoc_unground(): + + associative.index.clear() + associative.facts.clear() + + assoc_op = "assoc_op" + + fact(associative, assoc_op) + + x, y = var(), var() + xx, yy, zz = var(), var(), var() # Check results when both arguments are variables res = run(3, (x, y), eq_assoc(x, y)) exp_res_form = ( - (etuple(assoc_op, x, y, z), etuple(assoc_op, etuple(assoc_op, x, y), z)), - (x, y), + (etuple(assoc_op, xx, yy, zz), etuple(assoc_op, xx, etuple(assoc_op, yy, zz))), + (xx, yy), ( - etuple(etuple(assoc_op, x, y, z)), - etuple(etuple(assoc_op, etuple(assoc_op, x, y), z)), + etuple(etuple(assoc_op, xx, yy, zz)), + etuple(etuple(assoc_op, xx, etuple(assoc_op, yy, zz))), ), ) for a, b in zip(res, exp_res_form): s = unify(a, b) assert s is not False, (a, b) - assert all(isvar(i) for i in reify((x, y, z), s)) + assert all(isvar(i) for i in reify((xx, yy, zz), s)) - # Make sure it works with `cons` - res = run(0, (x, y), eq_assoc((assoc_op, 1, 2, 3), cons(x, y))) - assert res == ( - (assoc_op, ((assoc_op, 1, 2), 3)), - (assoc_op, (1, 2, 3)), - (assoc_op, (1, (assoc_op, 2, 3))), - ) - - res = run(1, (x, y), eq_assoc(cons(x, y), (x, z, 2, 3))) - assert res == ((assoc_op, ((assoc_op, z, 2), 3)),) - - # Don't use a predicate that can never succeed, e.g. 
- # associative_2 = Relation("associative_2") - # run(1, (x, y), eq_assoc(cons(x, y), (x, z), op_predicate=associative_2)) - - # Nested expressions should work now - expr1 = (assoc_op, (assoc_op, 1, 2), 3, 4, 5, 6) - expr2 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) - assert run(0, x, eq_assoc(expr1, expr2, n=2)) == ((assoc_op, 3, 4),) +def test_assoc_flatteno(): -def test_assoc_flatten(): + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() add = "add" mul = "mul" @@ -370,12 +521,13 @@ def test_assoc_flatten(): fact(associative, add) fact(commutative, mul) fact(associative, mul) + fact(associative, Add) assert ( run( 0, True, - assoc_flatten( + assoc_flatteno( (mul, 1, (add, 2, 3), (mul, 4, 5)), (mul, 1, (add, 2, 3), 4, 5) ), ) @@ -383,42 +535,117 @@ def test_assoc_flatten(): ) x = var() - assert ( - run( - 0, - x, - assoc_flatten((mul, 1, (add, 2, 3), (mul, 4, 5)), x), - ) - == ((mul, 1, (add, 2, 3), 4, 5),) + assert run(0, x, assoc_flatteno((mul, 1, (add, 2, 3), (mul, 4, 5)), x)) == ( + (mul, 1, (add, 2, 3), 4, 5), ) assert ( run( 0, True, - assoc_flatten( + assoc_flatteno( ("op", 1, (add, 2, 3), (mul, 4, 5)), ("op", 1, (add, 2, 3), (mul, 4, 5)) ), ) == (True,) ) - assert run(0, x, assoc_flatten(("op", 1, (add, 2, 3), (mul, 4, 5)), x)) == ( + assert run(0, x, assoc_flatteno(("op", 1, (add, 2, 3), (mul, 4, 5)), x)) == ( ("op", 1, (add, 2, 3), (mul, 4, 5)), ) + assert run( + 0, True, assoc_flatteno((add, 1, (add, 2, 3), (mul, 4, 5)), x, no_ident=True) + ) == (True,) + + assert ( + run( + 0, + True, + assoc_flatteno((add, 1, (mul, 2, 3), (mul, 4, 5)), x, no_ident=True), + ) + == () + ) + + assert run(0, True, assoc_flatteno((add, 1, 2, 3), x)) == (True,) + assert run(0, True, assoc_flatteno(Add(1, 2, 3), x)) == (True,) + assert run(0, True, assoc_flatteno((add, 1, 2, 3), x, no_ident=True)) == () + + res = run(0, x, assoc_flatteno(x, (add, 1, 2, 3), no_ident=True)) + assert sorted(res, key=str) == sorted( + [(add,) + r for r in assoc_args(add, (1, 2, 3)) if r != (add, 1, 2, 3)], key=str + ) + + res = run(0, x, assoc_flatteno(x, (add, 1, 2, 3), no_ident=True)) + assert sorted(res, key=str) == sorted( + [(add,) + r for r in assoc_args(add, (1, 2, 3)) if r != (add, 1, 2, 3)], key=str + ) + + +def test_assoc_flatteno_unground(): + + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() + + add = "add" + mul = "mul" + + fact(commutative, add) + fact(associative, add) + fact(commutative, mul) + fact(associative, mul) -def test_eq_assoccomm(): x, y = var(), var() - ac = "commassoc_op" + xx, yy, zz = var(), var(), var() + op_lv = var() + exp_res_form = ( + (xx, yy), + (etuple(op_lv, xx, yy), etuple(op_lv, xx, yy)), + (etuple(op_lv, xx, yy), etuple(op_lv, xx, yy)), + (etuple(op_lv, xx, etuple(op_lv, yy, zz)), etuple(op_lv, xx, yy, zz)), + ) + + res = run(4, (x, y), assoc_flatteno(x, y)) + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert s is not False, (a, b) + assert op_lv not in s or (s[op_lv],) in associative.facts + assert all(isvar(i) for i in reify((xx, yy, zz), s)) + + ww = var() + exp_res_form = ( + (etuple(op_lv, xx, etuple(op_lv, yy, zz)), etuple(op_lv, xx, yy, zz)), + (etuple(op_lv, xx, etuple(op_lv, yy, zz)), etuple(op_lv, xx, yy, zz)), + (etuple(op_lv, xx, etuple(op_lv, yy, zz, ww)), etuple(op_lv, xx, yy, zz, ww)), + ) + res = run(3, (x, y), assoc_flatteno(x, y, no_ident=True)) + + for a, b in zip(res, exp_res_form): + assert a[0] != a[1] + s = unify(a, b) + 
assert s is not False, (a, b) + assert op_lv not in s or (s[op_lv],) in associative.facts + assert all(isvar(i) for i in reify((xx, yy, zz, ww), s)) + + +def test_eq_assoccomm(): + + associative.index.clear() + associative.facts.clear() commutative.index.clear() commutative.facts.clear() + ac = "commassoc_op" + fact(commutative, ac) fact(associative, ac) + x, y = var(), var() + assert run(0, True, eq_assoccomm(1, 1)) == (True,) assert run(0, True, eq_assoccomm((1,), (1,))) == (True,) # assert run(0, True, eq_assoccomm(x, (1,))) == (True,) @@ -437,6 +664,37 @@ def test_eq_assoccomm(): True, ) + assert run(0, (x, y), eq_assoccomm((ac, 2, (ac, 3, 1)), (ac, (ac, 1, x), y))) == ( + (2, 3), + (3, 2), + ) + # assert run(0, (x, y), eq_assoccomm((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) == ( + # (2, 3), + # (3, 2), + # ) + + assert run(0, True, eq_assoccomm((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) == (True,) + assert run(0, True, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) == (True,) + assert run(0, True, eq_assoccomm((ac, 1, 1), ("other_op", 1, 1))) == () + + assert run(0, x, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, x, 3))) == (2,) + + +def test_eq_assoccomm_all_variations(): + + associative.index.clear() + associative.facts.clear() + commutative.index.clear() + commutative.facts.clear() + + ac = "commassoc_op" + + fact(commutative, ac) + fact(associative, ac) + + x = var() + + # TODO: Use four arguments to see real associative variation. exp_res = set( ( (ac, 1, 3, 2), @@ -462,23 +720,21 @@ def test_eq_assoccomm(): assert set(run(0, x, eq_assoccomm((ac, 1, (ac, 2, 3)), x))) == exp_res assert set(run(0, x, eq_assoccomm((ac, 1, 3, 2), x))) == exp_res assert set(run(0, x, eq_assoccomm((ac, 2, (ac, 3, 1)), x))) == exp_res - # LHS variations - # assert set(run(0, x, eq_assoccomm(x, (ac, 1, (ac, 2, 3))))) == exp_res - assert run(0, (x, y), eq_assoccomm((ac, 2, (ac, 3, 1)), (ac, (ac, 1, x), y))) == ( - (2, 3), - (3, 2), - ) - # assert run(0, (x, y), eq_assoccomm((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) == ( - # (2, 3), - # (3, 2), - # ) - assert run(0, True, eq_assoccomm((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) == (True,) - assert run(0, True, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) == (True,) - assert run(0, True, eq_assoccomm((ac, 1, 1), ("other_op", 1, 1))) == () +def test_eq_assoccomm_unground(): - assert run(0, x, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, x, 3))) == (2,) + associative.index.clear() + associative.facts.clear() + commutative.index.clear() + commutative.facts.clear() + + ac = "commassoc_op" + + fact(commutative, ac) + fact(associative, ac) + + x, y = var(), var() # Both arguments unground op_lv = var() @@ -488,36 +744,33 @@ def test_eq_assoccomm(): (etuple(op_lv, x, y), etuple(op_lv, y, x)), (y, y), ( - etuple(etuple(op_lv, x, y)), - etuple(etuple(op_lv, y, x)), - ), - ( - etuple(op_lv, x, y, z), - etuple(op_lv, etuple(op_lv, x, y), z), + etuple(op_lv, x, y), + etuple(op_lv, y, x), ), + (etuple(op_lv, x, etuple(op_lv, y, z)), etuple(op_lv, x, etuple(op_lv, z, y))), ) for a, b in zip(res, exp_res_form): s = unify(a, b) + assert s is not False, (a, b) assert ( op_lv not in s or (s[op_lv],) in associative.facts or (s[op_lv],) in commutative.facts ) - assert s is not False, (a, b) assert all(isvar(i) for i in reify((x, y, z), s)) -def test_assoccomm_algebra(): - - add = "add" - mul = "mul" +def test_eq_assoccomm_algebra(): commutative.index.clear() commutative.facts.clear() associative.index.clear() associative.facts.clear() + add = "add" + mul = "mul" + fact(commutative, add) 
fact(associative, add) fact(commutative, mul) @@ -531,7 +784,7 @@ def test_assoccomm_algebra(): assert run(0, (x, y), eq_assoccomm(pattern, expr)) == ((3, 2),) -def test_assoccomm_objects(): +def test_eq_assoccomm_objects(): commutative.index.clear() commutative.facts.clear() @@ -549,3 +802,41 @@ def test_assoccomm_objects(): # results in the following test (i.e. `(3, 3)`). assert run(0, x, eq_assoccomm(Add(1, 2, 3), Add(1, 2, x))) == (3,) assert run(0, x, eq_assoccomm(Add(1, 2, 3), Add(x, 2, 1))) == (3,) + + +@pytest.mark.xfail(strict=False) +@pytest.mark.timeout(5) +def test_eq_assoccom_scaling(): + + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() + + add = "add" + mul = "mul" + + fact(commutative, add) + fact(associative, add) + fact(commutative, mul) + fact(associative, mul) + + # TODO: Make a low-depth term inequal (e.g. inequal at base) + import random + + from tests.utils import generate_term + + random.seed(2343243) + + test_graph = generate_term((add, mul), range(4), 5) + + # Change the root operator + test_graph_2 = list(test_graph) + test_graph_2[0] = add if test_graph_2[0] == mul else mul + + assert test_graph != test_graph_2 + assert test_graph[1:] == test_graph_2[1:] + + assert run(0, True, eq_assoccomm(test_graph, test_graph_2)) == () + + # TODO: Make a high-depth term inequal diff --git a/tests/test_term.py b/tests/test_term.py index 9c8a1c4..d28df26 100644 --- a/tests/test_term.py +++ b/tests/test_term.py @@ -83,8 +83,10 @@ def __call__(self, *args): def test_TermType(): assert issubclass(type(Add(1, 2)), TermType) assert isinstance(Add(1, 2), TermType) - assert not issubclass(type([1, 2]), TermType) - assert not isinstance([1, 2], TermType) + # assert not issubclass(type([1, 2]), TermType) + # assert not isinstance([1, 2], TermType) + assert issubclass(type([1, 2]), TermType) + assert isinstance([1, 2], TermType) assert not isinstance(ConsType, TermType) assert not issubclass(type(ConsType), TermType) diff --git a/tests/utils.py b/tests/utils.py index e69c6a1..4f04c41 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,5 @@ +import random + from kanren.term import arguments, operator @@ -58,3 +60,29 @@ def operator_Node(t): # @term.register(Operator, Sequence) # def term_Operator(op, args): # return Node(op, args) + + +def generate_term(ops, args, i=10, gen_fn=None): + + if gen_fn is not None: + + gen_res = gen_fn(ops, args, i) + + if gen_res is not None: + return gen_res + + g_op = random.choice(ops) + + if i > 0: + num_sub_graphs = len(args) // 2 + else: + num_sub_graphs = 0 + + g_args = random.sample(args, len(args) - num_sub_graphs) + g_args += [ + generate_term(ops, args, i=i - 1, gen_fn=gen_fn) for s in range(num_sub_graphs) + ] + + random.shuffle(g_args) + + return [g_op] + list(g_args) From 864d0806b15fdc6a70a5d007c48cc847d5560f8c Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:50 -0600 Subject: [PATCH 10/14] Add missing no_ident arguments --- kanren/assoccomm.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index a877203..33fecd5 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -214,12 +214,11 @@ def op_pred(sub_op): def eq_assoc_args( op, a_args, b_args, n=None, inner_eq=eq, no_ident=False, null_type=etuple ): - """Create a goal that applies associative unification to an operator and two sets of arguments. 
+ """Create a goal that applies associative unification to an operator and two sets of arguments. # noqa: E501 - This is a non-relational utility goal. It does assumes that the op and at - least one set of arguments are ground under the state in which it is - evaluated. - """ # noqa: E501 + This is a non-relational utility goal. It assumes that the op and at least + one set of arguments are ground under the state in which it is evaluated. + """ u_args, v_args = var(), var() def eq_assoc_args_goal(S): @@ -293,7 +292,7 @@ def eq_assoc_args_goal(S): ) -def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple): +def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple, no_ident=False): """Create a goal for associative unification of two terms. Warning: This goal walks the left-hand argument, `u`, so make that argument @@ -314,10 +313,10 @@ def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple): def assoc_args_unique(a, b, op, **kwargs): return eq_assoc_args(op, a, b, no_ident=True, null_type=null_type) - return term_walko(op_predicate, assoc_args_unique, u, v, n=n) + return term_walko(op_predicate, assoc_args_unique, u, v, n=n, no_ident=no_ident) -def eq_comm(u, v, op_predicate=commutative, null_type=etuple): +def eq_comm(u, v, op_predicate=commutative, null_type=etuple, no_ident=False): """Create a goal for commutative equality. Warning: This goal walks the left-hand argument, `u`, so make that argument @@ -339,7 +338,7 @@ def eq_comm(u, v, op_predicate=commutative, null_type=etuple): def permuteo_unique(x, y, op, **kwargs): return permuteo(x, y, no_ident=True, default_ConsNull=null_type) - return term_walko(op_predicate, permuteo_unique, u, v) + return term_walko(op_predicate, permuteo_unique, u, v, no_ident=no_ident) def eq_assoccomm(u, v, null_type=etuple): From b42e522dc0a39a72ec660fbb88aa82ce6bde9716 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:51 -0600 Subject: [PATCH 11/14] Add traces to dbgo --- kanren/core.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/kanren/core.py b/kanren/core.py index 1d60118..7db58b8 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -250,22 +250,36 @@ def run(n, x, *goals, results_filter=None): return tuple(take(n, results)) -def dbgo(*args, msg=None): # pragma: no cover - """Construct a goal that sets a debug trace and prints reified arguments.""" +def dbgo(*args, msg=None, pdb=False, print_asap=True, trace=True): # pragma: no cover + """Construct a goal that prints reified arguments and, optionally, sets a debug trace.""" from pprint import pprint + from unification import var + + trace_var = var("__dbgo_trace") def dbgo_goal(S): - nonlocal args - args = reify(args, S) + nonlocal args, msg, pdb, print_asap, trace_var, trace + + args_rf, trace_rf = reify((args, trace_var), S) + + if trace: + S = S.copy() + if isvar(trace_rf): + S[trace_var] = [(msg, tuple(str(a) for a in args_rf))] + else: + trace_rf.append((msg, tuple(str(a) for a in args_rf))) + S[trace_var] = trace_rf - if msg is not None: - print(msg) + if print_asap: + if msg is not None: + print(msg) + pprint(args_rf) - pprint(args) + if pdb: + import pdb - import pdb + pdb.set_trace() - pdb.set_trace() yield S return dbgo_goal From 1630006713213e0579ac607d62921968457c1952 Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Sun, 14 Nov 2021 13:12:52 -0600 Subject: [PATCH 12/14] Move ground ordering to mapo; make walko variadic; preserve type in ground_order_seqs --- kanren/core.py | 9 +++- kanren/graph.py | 127 ++++++++++++++++++++++++-------------------- tests/test_core.py | 2 + tests/test_graph.py | 105 +++++++++++++++++++++++++++++++++--- 4 files changed, 175 insertions(+), 68 deletions(-) diff --git a/kanren/core.py b/kanren/core.py index 7db58b8..aa89f69 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -176,7 +176,11 @@ def ground_order_seqs_goal(S): ): in_seqs_ord = zip(*sorted(zip(*in_seqs_rf), key=partial(key_fn, S))) - S_new = unify(list(out_seqs_rf), list(in_seqs_ord), S) + S_new = unify( + list(out_seqs_rf), + [type(j)(i) for i, j in zip(in_seqs_ord, in_seqs_rf)], + S, + ) if S_new is not False: yield S_new @@ -251,8 +255,9 @@ def run(n, x, *goals, results_filter=None): def dbgo(*args, msg=None, pdb=False, print_asap=True, trace=True): # pragma: no cover - """Construct a goal that prints reified arguments and, optionally, sets a debug trace.""" + """Construct a goal that prints reified arguments and, optionally, sets a debug trace.""" # noqa: E501 from pprint import pprint + from unification import var trace_var = var("__dbgo_trace") diff --git a/kanren/graph.py b/kanren/graph.py index eec5ea3..d4f93eb 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -3,13 +3,13 @@ from etuples import etuple from unification import isvar, reify, var -from .core import Zzz, conde, eq, fail, lall, succeed +from .core import Zzz, conde, eq, fail, ground_order_seqs, lall, succeed from .goals import conso, nullo from .term import applyo def mapo(*args, null_res=True, **kwargs): - """Apply a relation to corresponding elements in two sequences and succeed if the relation succeeds for all sets of elements. # noqa: E501 + """Apply a goal to corresponding elements in two sequences and succeed if the goal succeeds for all sets of elements. # noqa: E501 See `map_anyo` for parameter descriptions. """ @@ -17,40 +17,43 @@ def mapo(*args, null_res=True, **kwargs): def map_anyo( - relation, + goal, *args, null_type=list, null_res=False, + use_ground_ordering=True, _first=True, _any_succeed=False, - **kwargs + **kwargs, ): - """Apply a relation to corresponding elements across sequences and succeed if at least one set of elements succeeds. + """Apply a goal to corresponding elements across sequences and succeed if at least one set of elements succeeds. Parameters ---------- - relation: Callable + goal: Callable The goal to apply across elements (`car`s, specifically) of `args`. *args: Sequence Argument list containing terms that are walked and evaluated as - `relation(car(a_1), car(a_2), ...)`. + `goal(car(a_1), car(a_2), ...)`. null_type: optional An object that's a valid cdr for the collection type desired. If `False` (i.e. the default value), the cdr will be inferred from the inputs, or defaults to an empty list. null_res: bool Succeed on empty lists. + use_ground_ordering: bool + Order arguments by their "groundedness" before recursing. _first: bool Indicate whether or not this is the first iteration in a call to this goal constructor (in contrast to a recursive call). This is not a user-level parameter. _any_succeed: bool or None Indicate whether or not an iteration has succeeded in a recursive call - to this goal, or, if `None`, indicate that only the relation against the + to this goal, or, if `None`, indicate that only the goal against the `cars` should be checked (i.e. no "any" functionality). 
This is not a user-level parameter. **kwargs: dict - Keyword arguments to `relation`. + Keyword arguments to `goal`. """ # noqa: E501 cars = tuple(var() for a in args) @@ -58,16 +61,16 @@ def map_anyo( conde_branches = [ [ - Zzz(relation, *cars, **kwargs), + Zzz(goal, *cars, **kwargs), Zzz( map_anyo, - relation, + goal, *cdrs, null_type=null_type, null_res=null_res, _first=False, _any_succeed=True if _any_succeed is not None else None, - **kwargs + **kwargs, ), ] ] @@ -79,22 +82,30 @@ def map_anyo( + [ Zzz( map_anyo, - relation, + goal, *cdrs, null_type=null_type, null_res=null_res, _first=False, _any_succeed=_any_succeed, - **kwargs + **kwargs, ), ] ) else: nullo_condition = not _first or null_res + if use_ground_ordering: + args_ord = tuple(var() for t in args) + ground_order_goal = ground_order_seqs(args, args_ord) + else: + args_ord = args + ground_order_goal = succeed + return conde( [nullo(*args, default_ConsNull=null_type) if nullo_condition else fail], - [conso(car, cdr, arg) for car, cdr, arg in zip(cars, cdrs, args)] + [ground_order_goal] + + [conso(car, cdr, arg) for car, cdr, arg in zip(cars, cdrs, args_ord)] + [conde(*conde_branches)], ) @@ -109,15 +120,19 @@ def eq_length(u, v, default_ConsNull=list): return mapo(vararg_success, u, v, null_type=default_ConsNull) -def reduceo(relation, in_term, out_term, *args, **kwargs): - """Relate a term and the fixed-point of that term under a given relation. +def reduceo(goal, in_term, out_term, *args, **kwargs): + """Construct a goal that yields the fixed-point of another goal. - This includes the "identity" relation. + It simply tries to order the implicit `conde` recursion branches so that they + produce the fixed-point value first. All goals that follow are the reductions + leading up to the fixed-point. + + FYI: The results will include `eq(in_term, out_term)`. """ def reduceo_goal(s): - nonlocal in_term, out_term, relation, args, kwargs + nonlocal in_term, out_term, goal, args, kwargs in_term_rf, out_term_rf = reify((in_term, out_term), s) @@ -125,19 +140,19 @@ def reduceo_goal(s): term_rdcd = var() # Are we working "backward" and (potentially) "expanding" a graph - # (e.g. when the relation is a reduction rule)? + # (e.g. when the goal is a reduction rule)? is_expanding = isvar(in_term_rf) - # One application of the relation assigned to `term_rdcd` - single_apply_g = relation(in_term_rf, term_rdcd, *args, **kwargs) + # One application of the goal assigned to `term_rdcd` + single_apply_g = goal(in_term_rf, term_rdcd, *args, **kwargs) # Assign/equate (unify, really) the result of a single application to # the "output" term. single_res_g = eq(term_rdcd, out_term_rf) - # Recurse into applications of the relation (well, produce a goal that + # Recurse into applications of the goal (well, produce a goal that # will do that) - another_apply_g = reduceo(relation, term_rdcd, out_term_rf, *args, **kwargs) + another_apply_g = reduceo(goal, term_rdcd, out_term_rf, *args, **kwargs) # We want the fixed-point value to show up in the stream output # *first*, but that requires some checks. @@ -166,55 +181,51 @@ def reduceo_goal(s): def walko( goal, - graph_in, - graph_out, - rator_goal=None, + *terms, + pre_process_fn=None, null_type=etuple, map_rel=partial(map_anyo, null_res=True), ): - """Apply a binary relation between all nodes in two graphs. - - When `rator_goal` is used, the graphs are treated as term graphs, and the - multi-functions `rator`, `rands`, and `apply` are used to walk the graphs. 
- Otherwise, the graphs must be iterable according to `map_anyo`. + """Apply a goal between all nodes in a set of terms. Parameters ---------- goal: callable - A goal that is applied to all terms in the graph. - graph_in: object - The graph for which the left-hand side of a binary relation holds. - graph_out: object - The graph for which the right-hand side of a binary relation holds. - rator_goal: callable (default None) - A goal that is applied to the rators of a graph. When specified, - `goal` is only applied to rands and it must succeed along with the - rator goal in order to descend into sub-terms. + A goal that is applied to all terms and their sub-terms. + *terms: Sequence of objects + The terms to be walked. + pre_process_fn: callable (default None) + A goal with a signature of the form `(old_terms, new_terms)`, where + each argument is a list of corresponding terms. + This goal can be used to transform terms before walking them. null_type: type - The collection type used when it is not fully determined by the graph + The collection type used when it is not fully determined by the `terms` arguments. map_rel: callable - The map relation used to apply `goal` to a sub-graph. + The map goal used to apply `goal` to corresponding sub-terms. """ - rator_in, rands_in, rator_out, rands_out = var(), var(), var(), var() + if pre_process_fn is not None: + terms_pp = tuple(var() for t in terms) + pre_process_goal = pre_process_fn(*(terms + terms_pp)) + else: + terms_pp = terms + pre_process_goal = succeed _walko = partial( - walko, goal, rator_goal=rator_goal, null_type=null_type, map_rel=map_rel + walko, + goal, + pre_process_fn=pre_process_fn, + null_type=null_type, + map_rel=map_rel, ) - return conde( - [ - Zzz(goal, graph_in, graph_out), - ], - [ - applyo(rator_in, rands_in, graph_in), - applyo(rator_out, rands_out, graph_out), - Zzz(rator_goal, rator_in, rator_out), - Zzz(map_rel, _walko, rands_in, rands_out, null_type=null_type), - ] - if rator_goal is not None - else [Zzz(map_rel, _walko, graph_in, graph_out, null_type=null_type)], + return lall( + pre_process_goal, + conde( + [Zzz(goal, *terms_pp)], + [Zzz(map_rel, _walko, *terms_pp, null_type=null_type)], + ), ) @@ -226,7 +237,7 @@ def term_walko( null_type=etuple, no_ident=False, format_step=None, - **kwargs + **kwargs, ): """Construct a goal for walking a term graph. 
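A quick sketch of the new call forms introduced above, for reference; it is not part of the patch itself. It assumes the patched kanren.graph module together with stock kanren, etuples, and unification installs, and the "add"/"mul" operator names are placeholders chosen for the example. `walko` is now variadic (the goal first, then any number of terms), and `map_anyo` orders its argument sequences by groundedness unless `use_ground_ordering=False` is passed.

    from etuples import etuple
    from unification import var

    from kanren import eq, run
    from kanren.graph import map_anyo, walko

    x = var()
    expr = etuple("add", etuple("mul", 2, 1), etuple("add", 1, 1))

    # `walko(goal, *terms)`: with plain `eq` as the node-level goal, walking a
    # ground term against itself succeeds at least once.
    assert len(run(0, True, walko(eq, expr, expr))) > 0

    # `map_anyo` ground-orders its argument sequences before recursing; the
    # option can be switched off per call.
    res = run(2, x, map_anyo(eq, [1, 2, 3], x, use_ground_ordering=False))
    # Every reified result should be the list [1, 2, 3] (duplicates can occur).
    assert all(r == [1, 2, 3] for r in res)

The ground ordering is the piece that lets calls like
`walko(eq, [var(), 2], [var(), 3])` in the updated `test_walko_misc` terminate,
since the more ground argument is walked first.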
diff --git a/tests/test_core.py b/tests/test_core.py index 8444c44..eefbd52 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -298,3 +298,5 @@ def test_ground_order_seq(): assert res == ((((1, b), a), (2, b)),) res = run(0, (x, y), ground_order_seqs(["abc", "def"], [x, y])) assert res == (("abc", "def"),) + res = run(0, (x, y), ground_order_seqs([[1, 2], (1, 2)], [x, y])) + assert res == (([1, 2], (1, 2)),) diff --git a/tests/test_graph.py b/tests/test_graph.py index 51e50da..621b51f 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -11,7 +11,9 @@ from kanren import conde, eq, lall, run from kanren.constraints import isinstanceo +from kanren.core import Zzz from kanren.graph import eq_length, map_anyo, mapo, reduceo, walko +from kanren.term import applyo class OrderedFunction(object): @@ -71,11 +73,23 @@ def single_math_reduceo(expanded_term, reduced_term): math_reduceo = partial(reduceo, single_math_reduceo) + +def walko_term_map_rel(walko_goal, x, y, **kwargs): + rator_in, rator_out = var(), var() + rands_in, rands_out = var(), var() + + return lall( + applyo(rator_in, rands_in, x), + applyo(rator_out, rands_out, y), + eq(rator_in, rator_out), + Zzz(map_anyo, walko_goal, rands_in, rands_out, null_res=False, **kwargs), + ) + + walko_term = partial( walko, - rator_goal=eq, null_type=ExpressionTuple, - map_rel=partial(map_anyo, null_res=False), + map_rel=walko_term_map_rel, ) @@ -176,7 +190,7 @@ def test_map_anyo_types(): def test_map_anyo_misc(): - q_lv = var("q") + q_lv = var() res = run(0, q_lv, map_anyo(eq, [1, 2, 3], [1, 2, 3])) # TODO: Remove duplicate results @@ -201,13 +215,14 @@ def one_to_threeo(x, y): test_res = run(4, q_lv, map_anyo(math_reduceo, [1, etuple(add, 2, 2)], q_lv)) assert test_res == ([1, etuple(mul, 2, 2)],) - test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, var("z"))) + z = var() + test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, z)) assert all(isinstance(r, list) for r in test_res) - test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, var("z"), tuple)) + test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, z, null_type=tuple)) assert all(isinstance(r, tuple) for r in test_res) - x, y, z = var(), var(), var() + x, y = var(), var() def test_bin(a, b): return conde([eq(a, 1), eq(b, 2)]) @@ -231,6 +246,21 @@ def test_bin(a, b): assert s is not False assert all(isvar(i) for i in reify((x, y, z), s)) + # With ground ordering, this function should only be called once + n = 0 + + def eq_count(x, y): + nonlocal n + n += 1 + return eq(x, y) + + run(0, q_lv, map_anyo(eq_count, [x, y, 3], [y, x, 2])) + assert n == 1 + + n = 0 + run(0, q_lv, map_anyo(eq_count, [x, y, 3], [y, x, 2], use_ground_ordering=False)) + assert n == 3 + @pytest.mark.parametrize( "test_input, test_output", @@ -323,7 +353,7 @@ def test_map_anyo_reverse(): def test_walko_misc(): - q_lv = var(prefix="q") + q_lv = var() expr = etuple(add, etuple(mul, 2, 1), etuple(add, 1, 1)) res = run(0, q_lv, walko(eq, expr, expr)) @@ -353,6 +383,18 @@ def one_to_threeo(x, y): etuple(), ) + def map_rel(walk_g, x, y, **kwargs): + rator_in, rator_out = var(), var() + rands_in, rands_out = var(), var() + + return lall( + applyo(rator_in, rands_in, x), + applyo(rator_out, rands_out, y), + eq(rator_in, add), + eq(rator_out, add), + map_anyo(walk_g, x, y, **kwargs), + ) + res = run( 1, q_lv, @@ -366,7 +408,7 @@ def one_to_threeo(x, y): ), q_lv, # Only descend into `add` terms - rator_goal=lambda x, y: lall(eq(x, add), eq(y, add)), + map_rel=map_rel, ), ) @@ -376,6 +418,53 @@ def one_to_threeo(x, 
y): ), ) + # Now, we check that the `use_ground_order_seqs` option prevents infinite + # loops in `walko`. + + # This would go on forever between the two variable terms, even though + # the `(2, 3)` pair would cause it fail (if it were ever reached) + # run(1, True, walko(eq, [var(), 2], [var(), 3])) + run(1, True, walko(eq, [var(), 2], [var(), 3])) + + # Only walk rators for terms with the same car + def same_op(x, y, a, b): + rator_in = var() + rands_in, rands_out = var(), var() + + return conde( + [ + applyo(rator_in, rands_in, x), + applyo(rator_in, rands_out, y), + eq(a, rands_in), + eq(b, rands_out), + ], + [eq(a, None), eq(b, None)], + ) + + def one_to_threeo(x, y): + return conde([eq(x, 1), eq(y, 3)], [eq(x, None), eq(y, None)]) + + # This won't work without the pre-processing function, because + # `one_to_threeo` will fail when given `add, add` arguments + assert ( + run(0, True, walko(one_to_threeo, [add, 1, 1], [add, 3, 3], map_rel=mapo)) == () + ) + + assert ( + run( + 1, + True, + walko( + one_to_threeo, + [add, 1, 1], + [add, 3, 3], + pre_process_fn=same_op, + map_rel=mapo, + ), + ) + == (True,) + ) + @pytest.mark.parametrize( "test_input, test_output", From 8b2f2bf17bfcc900b35a99ed433b4211845dd620 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Sun, 14 Nov 2021 13:12:54 -0600 Subject: [PATCH 13/14] Add an all variations commutativity test --- tests/test_assoccomm.py | 101 ++++++++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index b0fb644..cb09eed 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -29,8 +29,6 @@ def results(g, s=None): def test_eq_comm(): - x, y, z = var(), var(), var() - commutative.facts.clear() commutative.index.clear() @@ -38,6 +36,8 @@ def test_eq_comm(): fact(commutative, comm_op) + x, y, z = var(), var(), var() + assert run(0, True, eq_comm(1, 1)) == (True,) assert run(0, True, eq_comm((comm_op, 1, 2, 3), (comm_op, 1, 2, 3))) == (True,) @@ -59,20 +59,6 @@ def test_eq_comm(): assert not run(0, True, eq_comm((comm_op, 1, 2, 1), (comm_op, 1, 2, 3))) assert not run(0, True, eq_comm(("op", 1, 2, 3), (comm_op, 1, 2, 3))) - # Test for variable args - res = run(4, (x, y), eq_comm(x, y)) - exp_res_form = ( - (etuple(comm_op, x, y), etuple(comm_op, y, x)), - (x, y), - (etuple(etuple(comm_op, x, y)), etuple(etuple(comm_op, y, x))), - (etuple(comm_op, x, y, z), etuple(comm_op, x, z, y)), - ) - - for a, b in zip(res, exp_res_form): - s = unify(a, b) - assert s is not False - assert all(isvar(i) for i in reify((x, y, z), s)) - # Make sure it can unify single elements assert (3,) == run(0, x, eq_comm((comm_op, 1, 2, 3), (comm_op, 2, x, 1))) @@ -112,9 +98,53 @@ def test_eq_comm(): e2 = (y, (comm_op, 1, x)) assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) - e1 = (comm_op, 2, (comm_op, 3, 1)) - e2 = (comm_op, (comm_op, 1, x), y) - assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) + +def test_eq_comm_all_variations(): + commutative.facts.clear() + commutative.index.clear() + + comm_op = "comm_op" + + fact(commutative, comm_op) + + expected_res = { + (comm_op, 1, (comm_op, 2, (comm_op, 3, 4))), + (comm_op, 1, (comm_op, 2, (comm_op, 4, 3))), + (comm_op, 1, (comm_op, (comm_op, 3, 4), 2)), + (comm_op, 1, (comm_op, (comm_op, 4, 3), 2)), + (comm_op, (comm_op, 2, (comm_op, 3, 4)), 1), + (comm_op, (comm_op, 2, (comm_op, 4, 3)), 1), + (comm_op, (comm_op, (comm_op, 3, 4), 2), 1), + (comm_op, (comm_op, (comm_op, 4, 3), 2), 1), + } + + x = var() + 
res = run(0, x, eq_comm((comm_op, 1, (comm_op, 2, (comm_op, 3, 4))), x)) + assert sorted(res, key=str) == sorted(expected_res, key=str) + + +def test_eq_comm_unground(): + commutative.facts.clear() + commutative.index.clear() + + comm_op = "comm_op" + + fact(commutative, comm_op) + + x, y, z = var(), var(), var() + # Test for variable args + res = run(4, (x, y), eq_comm(x, y)) + exp_res_form = ( + (etuple(comm_op, x, y), etuple(comm_op, y, x)), + (x, y), + (etuple(etuple(comm_op, x, y)), etuple(etuple(comm_op, y, x))), + (etuple(comm_op, x, y, z), etuple(comm_op, x, z, y)), + ) + + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert s is not False + assert all(isvar(i) for i in reify((x, y, z), s)) @pytest.mark.xfail(reason="`applyo`/`buildo` needs to be a constraint.", strict=True) @@ -423,14 +453,6 @@ def test_eq_assoc(): expr2 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) assert run(0, x, eq_assoc(expr1, expr2, n=2)) == ((assoc_op, 3, 4),) - # TODO: Need groundedness ordering for this - # res = run(0, x, eq_assoc(x, (assoc_op, 1, 2, 3), n=2)) - # assert res == ( - # (assoc_op, (assoc_op, 1, 2), 3), - # (assoc_op, 1, 2, 3), - # (assoc_op, 1, (assoc_op, 2, 3)), - # ) - @pytest.mark.xfail(strict=False) def test_eq_assoc_cons(): @@ -458,21 +480,20 @@ def test_eq_assoc_all_variations(): fact(associative, assoc_op) x = var() + # Normalized, our results are left-associative, i.e. + # (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4) == (assoc_op, 1, 2, 3, 4) expected_res = { - # Normalized, our results are left-associative, i.e. - # (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4), - # is equal to the following: - (assoc_op, 1, 2, 3, 4), + (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4), # Missing + (assoc_op, (assoc_op, 1, (assoc_op, 2, 3)), 4), # Missing + (assoc_op, (assoc_op, 1, 2), (assoc_op, 3, 4)), (assoc_op, (assoc_op, 1, 2), 3, 4), - (assoc_op, 1, (assoc_op, 2, 3), 4), - (assoc_op, 1, 2, (assoc_op, 3, 4)), (assoc_op, (assoc_op, 1, 2, 3), 4), + (assoc_op, 1, (assoc_op, (assoc_op, 2, 3), 4)), # Missing + (assoc_op, 1, (assoc_op, 2, (assoc_op, 3, 4))), # Missing + (assoc_op, 1, (assoc_op, 2, 3), 4), (assoc_op, 1, (assoc_op, 2, 3, 4)), - (assoc_op, (assoc_op, 1, 2), (assoc_op, 3, 4)), - (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4), - (assoc_op, (assoc_op, 1, (assoc_op, 2, 3)), 4), - (assoc_op, 1, (assoc_op, (assoc_op, 2, 3), 4)), - (assoc_op, 1, (assoc_op, 2, (assoc_op, 3, 4))), + (assoc_op, 1, 2, (assoc_op, 3, 4)), + (assoc_op, 1, 2, 3, 4), } res = run(0, x, eq_assoc((assoc_op, 1, 2, 3, 4), x)) assert sorted(res, key=str) == sorted(expected_res, key=str) @@ -806,7 +827,7 @@ def test_eq_assoccomm_objects(): @pytest.mark.xfail(strict=False) @pytest.mark.timeout(5) -def test_eq_assoccom_scaling(): +def test_eq_assoccomm_scaling(): commutative.index.clear() commutative.facts.clear() @@ -821,7 +842,6 @@ def test_eq_assoccom_scaling(): fact(commutative, mul) fact(associative, mul) - # TODO: Make a low-depth term inequal (e.g. inequal at base) import random from tests.utils import generate_term @@ -830,6 +850,7 @@ def test_eq_assoccom_scaling(): test_graph = generate_term((add, mul), range(4), 5) + # Make a low-depth term inequal (e.g. inequal at base): # Change the root operator test_graph_2 = list(test_graph) test_graph_2[0] = add if test_graph_2[0] == mul else mul From ebba96d53c692109f69f00316b4431b81ed9cca2 Mon Sep 17 00:00:00 2001 From: "Brandon T. 
Willard" Date: Sun, 14 Nov 2021 13:12:55 -0600 Subject: [PATCH 14/14] Make assoccomm all variations tests more robust --- tests/test_assoccomm.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index cb09eed..8b2c39b 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -119,8 +119,10 @@ def test_eq_comm_all_variations(): } x = var() - res = run(0, x, eq_comm((comm_op, 1, (comm_op, 2, (comm_op, 3, 4))), x)) - assert sorted(res, key=str) == sorted(expected_res, key=str) + + for s in expected_res: + res = run(0, x, eq_comm(s, x)) + assert sorted(res, key=str) == sorted(expected_res, key=str) def test_eq_comm_unground(): @@ -495,8 +497,10 @@ def test_eq_assoc_all_variations(): (assoc_op, 1, 2, (assoc_op, 3, 4)), (assoc_op, 1, 2, 3, 4), } - res = run(0, x, eq_assoc((assoc_op, 1, 2, 3, 4), x)) - assert sorted(res, key=str) == sorted(expected_res, key=str) + + for s in expected_res: + res = run(0, x, eq_assoc(s, x)) + assert sorted(res, key=str) == sorted(expected_res, key=str) def test_eq_assoc_unground(): @@ -716,7 +720,7 @@ def test_eq_assoccomm_all_variations(): x = var() # TODO: Use four arguments to see real associative variation. - exp_res = set( + expected_res = set( ( (ac, 1, 3, 2), (ac, 1, 2, 3), @@ -738,9 +742,13 @@ def test_eq_assoccomm_all_variations(): (ac, (ac, 2, 1), 3), ) ) - assert set(run(0, x, eq_assoccomm((ac, 1, (ac, 2, 3)), x))) == exp_res - assert set(run(0, x, eq_assoccomm((ac, 1, 3, 2), x))) == exp_res - assert set(run(0, x, eq_assoccomm((ac, 2, (ac, 3, 1)), x))) == exp_res + + for s in expected_res: + res = run(0, x, eq_assoccomm(s, x)) + # TODO FIXME: Avoid the extra identity result + res = list(res) + res.remove(s) + assert sorted(res, key=str) == sorted(expected_res, key=str) def test_eq_assoccomm_unground():