diff --git a/kanren/assoccomm.py b/kanren/assoccomm.py index 99edc58..33fecd5 100644 --- a/kanren/assoccomm.py +++ b/kanren/assoccomm.py @@ -4,97 +4,221 @@ accomplishes this through naively trying all possibilities. This was built to be used in the computer algebra systems SymPy and Theano. ->>> from kanren import run, var, fact ->>> from kanren.assoccomm import eq_assoccomm as eq ->>> from kanren.assoccomm import commutative, associative - ->>> # Define some dummy Ops ->>> add = 'add' ->>> mul = 'mul' - ->>> # Declare that these ops are commutative using the facts system ->>> fact(commutative, mul) ->>> fact(commutative, add) ->>> fact(associative, mul) ->>> fact(associative, add) - ->>> # Define some wild variables ->>> x, y = var('x'), var('y') - ->>> # Two expressions to match ->>> pattern = (mul, (add, 1, x), y) # (1 + x) * y ->>> expr = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) - ->>> print(run(0, (x,y), eq(pattern, expr))) -((3, 2),) """ from collections.abc import Sequence from functools import partial from operator import eq as equal from operator import length_hint -from cons.core import ConsPair, car, cdr from etuples import etuple -from toolz import sliding_window from unification import reify, unify, var -from .core import conde, eq, ground_order, lall, succeed +from .core import Zzz, conde, eq, fail, ground_order, lall, lany, succeed from .facts import Relation -from .goals import itero, permuteo +from .goals import conso, itero, permuteo from .graph import term_walko -from .term import term +from .term import TermType, applyo, arguments, operator, term associative = Relation("associative") commutative = Relation("commutative") -def flatten_assoc_args(op_predicate, items): - for i in items: - if isinstance(i, ConsPair) and op_predicate(car(i)): - i_cdr = cdr(i) - if length_hint(i_cdr) > 0: - yield from flatten_assoc_args(op_predicate, i_cdr) +def flatten_assoc_args(op_predicate, term, shallow=True): + """Flatten/normalize a term with an associative operator. + + Parameters + ---------- + op_predicate: callable + A function used to determine the operators to flatten. + items: Sequence + The term to flatten. + shallow: bool (optional) + Indicate whether or not flattening should be done at all depths. + """ + + if not isinstance(term, TermType): + return term + + def _flatten(term): + for i in term: + if isinstance(i, TermType) and op_predicate(operator(i)): + i_cdr = arguments(i) + if length_hint(i_cdr) > 0: + if shallow: + yield from iter(i_cdr) + else: + yield from _flatten(i_cdr) + else: + yield i else: yield i - else: - yield i + term_type = type(arguments(term)) + return term_type(_flatten(term)) + + +def partitions(in_seq, n_parts=None, min_size=1, part_fn=lambda x: x): + """Generate all partitions of a sequence for given numbers of partitions and minimum group sizes. # noqa: E501 + + Parameters + ---------- + in_seq: Sequence + The sequence to be partitioned. + n_parts: int + Number of partitions. `None` means all partitions in `range(2, len(in_seq))`. + min_size: int + The minimum size of a partition. + part_fn: Callable + A function applied to every partition. + """ + + def partition(seq, res): + if ( + n_parts is None + and + # We don't want the original sequence + len(res) > 0 + ) or len(res) + 1 == n_parts: + yield type(in_seq)(res + [part_fn(seq)]) + + if n_parts is not None: + return + + for s in range(min_size, len(seq) + 1 - min_size, 1): + yield from partition(seq[s:], res + [part_fn(seq[:s])]) + + return partition(in_seq, []) -def assoc_args(rator, rands, n, ctor=None): - """Produce all associative argument combinations of rator + rands in n-sized rand groupings. - >>> from kanren.assoccomm import assoc_args - >>> list(assoc_args('op', [1, 2, 3], 2)) - [[['op', 1, 2], 3], [1, ['op', 2, 3]]] - """ # noqa: E501 - assert n > 0 +def assoc_args(rator, rands, n=None, ctor=None): + """Produce all associative argument combinations of rator + rands in n-sized rand groupings. # noqa: E501 - rands_l = list(rands) + The normal/canonical form is left-associative, e.g. + `(op, 1, 2, 3, 4) == (op, (op, (op, 1, 2), 3), 4)` + + Parameters + ---------- + rator: object + The operator that's assumed to be associative. + rands: Sequence + The operands. + n: int (optional) + The number of arguments in the resulting `(op,) + output` terms. + If not specified, all argument sizes are returned. + ctor: callable + The constructor to use when each term is created. + If not specified, the constructor is inferred from `type(rands)`. + """ if ctor is None: ctor = type(rands) - if n == len(rands_l): + if len(rands) <= 2 or n is not None and len(rands) <= n: yield ctor(rands) return - for i, new_rands in enumerate(sliding_window(n, rands_l)): - prefix = rands_l[:i] - new_term = term(rator, ctor(new_rands)) - suffix = rands_l[n + i :] - res = ctor(prefix + [new_term] + suffix) - yield res + def part_fn(x): + if len(x) == 1: + return x[0] + else: + return term(rator, ctor(x)) + + for p in partitions(rands, n, 1, part_fn=part_fn): + yield ctor(p) + + +def assoc_flatteno(a, a_flat, no_ident=False, null_type=etuple): + """Construct a goal that flattens/normalizes terms with associative operators. + + The normal/canonical form is left-associative, e.g. + `(op, 1, 2, 3, 4) == (op, (op, (op, 1, 2), 3), 4)` + + Parameters + ---------- + a: Var or Sequence + The "input" term to flatten. + a_flat: Var or Sequence + The flattened result, or "output", term. + no_ident: bool (optional) + Whether or not to fail if no flattening occurs. + """ + + def assoc_flatteno_goal(S): + nonlocal a, a_flat + + a_rf, a_flat_rf = reify((a, a_flat), S) + + if isinstance(a_rf, TermType) and (operator(a_rf),) in associative.facts: + + a_op = operator(a_rf) + args_rf = arguments(a_rf) + + def op_pred(sub_op): + return sub_op == a_op + + a_flat_rf = term(a_op, flatten_assoc_args(op_pred, args_rf)) + + if a_flat_rf == a_rf and no_ident: + return + + yield from eq(a_flat, a_flat_rf)(S) + + elif ( + isinstance(a_flat_rf, TermType) + and (operator(a_flat_rf),) in associative.facts + ): + + a_flat_op = operator(a_flat_rf) + args_rf = arguments(a_flat_rf) + assoc_args_iter = assoc_args(a_flat_op, args_rf) + + # TODO: There are much better ways to do this `no_ident` check + # (e.g. the `n` argument of `assoc_args` should probably be made to + # work for this) + yield from lany( + fail if no_ident and r is args_rf else applyo(a_flat_op, r, a_rf) + for r in assoc_args_iter + )(S) + + else: + + op = var() + a_rands = var() + a_rands_rands = var() + a_flat_rands = var() + a_flat_rands_rands = var() + + g = conde( + [fail if no_ident else eq(a_rf, a_flat_rf)], + [ + associative(op), + applyo(op, a_rands, a_rf), + applyo(op, a_flat_rands, a_flat_rf), + # There must be at least two rands + conso(var(), a_rands_rands, a_rands), + conso(var(), var(), a_rands_rands), + conso(var(), a_flat_rands_rands, a_flat_rands), + conso(var(), var(), a_flat_rands_rands), + itero( + a_flat_rands, nullo_refs=(a_rands,), default_ConsNull=null_type + ), + Zzz(assoc_flatteno, a_rf, a_flat_rf, no_ident=no_ident), + ], + ) + + yield from g(S) + + return assoc_flatteno_goal def eq_assoc_args( op, a_args, b_args, n=None, inner_eq=eq, no_ident=False, null_type=etuple ): - """Create a goal that applies associative unification to an operator and two sets of arguments. + """Create a goal that applies associative unification to an operator and two sets of arguments. # noqa: E501 - This is a non-relational utility goal. It does assumes that the op and at - least one set of arguments are ground under the state in which it is - evaluated. - """ # noqa: E501 + This is a non-relational utility goal. It assumes that the op and at least + one set of arguments are ground under the state in which it is evaluated. + """ u_args, v_args = var(), var() def eq_assoc_args_goal(S): @@ -118,15 +242,17 @@ def eq_assoc_args_goal(S): u_args_flat = type(u_args_rf)(flatten_assoc_args(op_pred, u_args_rf)) v_args_flat = type(v_args_rf)(flatten_assoc_args(op_pred, v_args_rf)) - if len(u_args_flat) == len(v_args_flat): + u_len, v_len = len(u_args_flat), len(v_args_flat) + if u_len == v_len: g = inner_eq(u_args_flat, v_args_flat) else: - if len(u_args_flat) < len(v_args_flat): + if u_len < v_len: sm_args, lg_args = u_args_flat, v_args_flat + grp_sizes = u_len else: sm_args, lg_args = v_args_flat, u_args_flat + grp_sizes = v_len - grp_sizes = len(lg_args) - len(sm_args) + 1 assoc_terms = assoc_args( op_rf, lg_args, grp_sizes, ctor=type(u_args_rf) ) @@ -149,20 +275,13 @@ def eq_assoc_args_goal(S): u_args_flat = list(flatten_assoc_args(partial(equal, op_rf), u_args_rf)) - if n_rf is not None: - arg_sizes = [n_rf] - else: - arg_sizes = range(2, len(u_args_flat) + (not no_ident)) - - v_ac_args = ( - v_ac_arg - for n_i in arg_sizes + g = conde( + [inner_eq(v_args_rf, v_ac_arg)] for v_ac_arg in assoc_args( - op_rf, u_args_flat, n_i, ctor=type(u_args_rf) + op_rf, u_args_flat, n_rf, ctor=type(u_args_rf) ) if not no_ident or v_ac_arg != u_args_rf ) - g = conde([inner_eq(v_args_rf, v_ac_arg)] for v_ac_arg in v_ac_args) yield from g(S) @@ -173,9 +292,13 @@ def eq_assoc_args_goal(S): ) -def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple): +def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple, no_ident=False): """Create a goal for associative unification of two terms. + Warning: This goal walks the left-hand argument, `u`, so make that argument + the most ground term; otherwise, it may iterate indefinitely when it should + actually terminate. + >>> from kanren import run, var, fact >>> from kanren.assoccomm import eq_assoc as eq @@ -190,12 +313,16 @@ def eq_assoc(u, v, n=None, op_predicate=associative, null_type=etuple): def assoc_args_unique(a, b, op, **kwargs): return eq_assoc_args(op, a, b, no_ident=True, null_type=null_type) - return term_walko(op_predicate, assoc_args_unique, u, v, n=n) + return term_walko(op_predicate, assoc_args_unique, u, v, n=n, no_ident=no_ident) -def eq_comm(u, v, op_predicate=commutative, null_type=etuple): +def eq_comm(u, v, op_predicate=commutative, null_type=etuple, no_ident=False): """Create a goal for commutative equality. + Warning: This goal walks the left-hand argument, `u`, so make that argument + the most ground term; otherwise, it may iterate indefinitely when it should + actually terminate. + >>> from kanren import run, var, fact >>> from kanren.assoccomm import eq_comm as eq >>> from kanren.assoccomm import commutative, associative @@ -211,34 +338,16 @@ def eq_comm(u, v, op_predicate=commutative, null_type=etuple): def permuteo_unique(x, y, op, **kwargs): return permuteo(x, y, no_ident=True, default_ConsNull=null_type) - return term_walko(op_predicate, permuteo_unique, u, v) - - -def assoc_flatten(a, a_flat): - def assoc_flatten_goal(S): - nonlocal a, a_flat - - a_rf = reify(a, S) - - if isinstance(a_rf, Sequence) and (a_rf[0],) in associative.facts: - - def op_pred(sub_op): - nonlocal S - sub_op_rf = reify(sub_op, S) - return sub_op_rf == a_rf[0] - - a_flat_rf = type(a_rf)(flatten_assoc_args(op_pred, a_rf)) - else: - a_flat_rf = a_rf - - yield from eq(a_flat, a_flat_rf)(S) - - return assoc_flatten_goal + return term_walko(op_predicate, permuteo_unique, u, v, no_ident=no_ident) def eq_assoccomm(u, v, null_type=etuple): """Construct a goal for associative and commutative unification. + Warning: This goal walks the left-hand argument, `u`, so make that argument + the most ground term; otherwise, it may iterate indefinitely when it should + actually terminate. + >>> from kanren.assoccomm import eq_assoccomm as eq >>> from kanren.assoccomm import commutative, associative >>> from kanren import fact, run, var @@ -275,6 +384,6 @@ def eq_assoccomm_step(a, b, op): eq_assoccomm_step, u, v, - format_step=assoc_flatten, + format_step=assoc_flatteno, no_ident=False, ) diff --git a/kanren/core.py b/kanren/core.py index 3da6b60..aa89f69 100644 --- a/kanren/core.py +++ b/kanren/core.py @@ -1,12 +1,13 @@ -from collections.abc import Generator, Sequence +from collections.abc import Generator, Mapping, Sequence from functools import partial, reduce from itertools import tee from operator import length_hint +from statistics import mean -from cons.core import ConsPair +from cons.core import ConsPair, car, cdr +from multipledispatch import dispatch from toolz import interleave, take from unification import isvar, reify, unify -from unification.core import isground def fail(s): @@ -113,18 +114,25 @@ def conde(*goals): lany = ldisj -def ground_order_key(S, x): +@dispatch(Mapping, object) +def shallow_ground_order_key(S, x): if isvar(x): - return 2 - elif isground(x, S): - return -1 - elif issubclass(type(x), ConsPair): - return 1 - else: - return 0 - - -def ground_order(in_args, out_args): + return 10 + elif isinstance(x, ConsPair): + val = 0 + val += 1 if isvar(car(x)) else 0 + cdr_x = cdr(x) + if issubclass(type(x), ConsPair): + val += 2 if isvar(cdr_x) else 0 + elif len(cdr_x) == 1: + val += 1 if isvar(cdr_x[0]) else 0 + elif len(cdr_x) > 1: + val += mean(1.0 if isvar(i) else 0.0 for i in cdr_x) + return val + return 0 + + +def ground_order(in_args, out_args, key_fn=shallow_ground_order_key): """Construct a non-relational goal that orders a list of terms based on groundedness (grounded precede ungrounded).""" # noqa: E501 def ground_order_goal(S): @@ -134,7 +142,7 @@ def ground_order_goal(S): S_new = unify( list(out_args_rf) if isinstance(out_args_rf, Sequence) else out_args_rf, - sorted(in_args_rf, key=partial(ground_order_key, S)), + sorted(in_args_rf, key=partial(key_fn, S)), S, ) @@ -144,6 +152,48 @@ def ground_order_goal(S): return ground_order_goal +def ground_order_seqs(in_seqs, out_seqs, key_fn=shallow_ground_order_key): + """Construct a non-relational goal that orders lists of sequences based on the groundedness of their corresponding terms. # noqa: E501 + + >>> from unification import var + >>> x, y = var('x'), var('y') + >>> a, b = var('a'), var('b') + >>> run(0, (x, y), ground_order_seqs([(a, b), (b, 2)], [x, y])) + (((~b, ~a), (2, ~b)),) + """ + + def ground_order_seqs_goal(S): + nonlocal in_seqs, out_seqs, key_fn + + in_seqs_rf, out_seqs_rf = reify((in_seqs, out_seqs), S) + + if ( + not any(isinstance(s, str) for s in in_seqs_rf) + and reduce( + lambda x, y: x == y and y, (length_hint(s, -1) for s in in_seqs_rf) + ) + > 0 + ): + + in_seqs_ord = zip(*sorted(zip(*in_seqs_rf), key=partial(key_fn, S))) + S_new = unify( + list(out_seqs_rf), + [type(j)(i) for i, j in zip(in_seqs_ord, in_seqs_rf)], + S, + ) + + if S_new is not False: + yield S_new + else: + + S_new = unify(out_seqs_rf, in_seqs_rf, S) + + if S_new is not False: + yield S_new + + return ground_order_seqs_goal + + def ifa(g1, g2): """Create a goal operator that returns the first stream unless it fails.""" @@ -204,22 +254,37 @@ def run(n, x, *goals, results_filter=None): return tuple(take(n, results)) -def dbgo(*args, msg=None): # pragma: no cover - """Construct a goal that sets a debug trace and prints reified arguments.""" +def dbgo(*args, msg=None, pdb=False, print_asap=True, trace=True): # pragma: no cover + """Construct a goal that prints reified arguments and, optionally, sets a debug trace.""" # noqa: E501 from pprint import pprint + from unification import var + + trace_var = var("__dbgo_trace") + def dbgo_goal(S): - nonlocal args - args = reify(args, S) + nonlocal args, msg, pdb, print_asap, trace_var, trace + + args_rf, trace_rf = reify((args, trace_var), S) + + if trace: + S = S.copy() + if isvar(trace_rf): + S[trace_var] = [(msg, tuple(str(a) for a in args_rf))] + else: + trace_rf.append((msg, tuple(str(a) for a in args_rf))) + S[trace_var] = trace_rf - if msg is not None: - print(msg) + if print_asap: + if msg is not None: + print(msg) + pprint(args_rf) - pprint(args) + if pdb: + import pdb - import pdb + pdb.set_trace() - pdb.set_trace() yield S return dbgo_goal diff --git a/kanren/graph.py b/kanren/graph.py index 3fad195..d4f93eb 100644 --- a/kanren/graph.py +++ b/kanren/graph.py @@ -3,80 +3,110 @@ from etuples import etuple from unification import isvar, reify, var -from .core import Zzz, conde, eq, fail, ground_order, lall, succeed +from .core import Zzz, conde, eq, fail, ground_order_seqs, lall, succeed from .goals import conso, nullo from .term import applyo -def mapo(relation, a, b, null_type=list, null_res=True, first=True): - """Apply a relation to corresponding elements in two sequences and succeed if the relation succeeds for all pairs.""" # noqa: E501 +def mapo(*args, null_res=True, **kwargs): + """Apply a goal to corresponding elements in two sequences and succeed if the goal succeeds for all sets of elements. # noqa: E501 - b_car, b_cdr = var(), var() - a_car, a_cdr = var(), var() - - return conde( - [nullo(a, b, default_ConsNull=null_type) if (not first or null_res) else fail], - [ - conso(a_car, a_cdr, a), - conso(b_car, b_cdr, b), - Zzz(relation, a_car, b_car), - Zzz(mapo, relation, a_cdr, b_cdr, null_type=null_type, first=False), - ], - ) + See `map_anyo` for parameter descriptions. + """ + return map_anyo(*args, null_res=null_res, _first=True, _any_succeed=None, **kwargs) def map_anyo( - relation, a, b, null_type=list, null_res=False, first=True, any_succeed=False + goal, + *args, + null_type=list, + null_res=False, + use_ground_ordering=True, + _first=True, + _any_succeed=False, + **kwargs, ): - """Apply a relation to corresponding elements in two sequences and succeed if at least one pair succeeds. + """Apply a goal to corresponding elements across sequences and succeed if at least one set of elements succeeds. Parameters ---------- + goal: Callable + The goal to apply across elements (`car`s, specifically) of `args`. + *args: Sequence + Argument list containing terms that are walked and evaluated as + `goal(car(a_1), car(a_2), ...)`. null_type: optional An object that's a valid cdr for the collection type desired. If `False` (i.e. the default value), the cdr will be inferred from the inputs, or defaults to an empty list. + null_res: bool + Succeed on empty lists. + use_ground_ordering: bool + Order arguments by their "groundedness" before recursing. + _first: bool + Indicate whether or not this is the first iteration in a call to this + goal constructor (in contrast to a recursive call). + This is not a user-level parameter. + _any_succeed: bool or None + Indicate whether or not an iteration has succeeded in a recursive call + to this goal, or, if `None`, indicate that only the goal against the + `cars` should be checked (i.e. no "any" functionality). + This is not a user-level parameter. + **kwargs: dict + Keyword arguments to `goal`. """ # noqa: E501 - b_car, b_cdr = var(), var() - a_car, a_cdr = var(), var() + cars = tuple(var() for a in args) + cdrs = tuple(var() for a in args) - return conde( - [ - nullo(a, b, default_ConsNull=null_type) - if (any_succeed or (first and null_res)) - else fail - ], + conde_branches = [ [ - conso(a_car, a_cdr, a), - conso(b_car, b_cdr, b), - conde( - [ - Zzz(relation, a_car, b_car), - Zzz( - map_anyo, - relation, - a_cdr, - b_cdr, - null_type=null_type, - any_succeed=True, - first=False, - ), - ], - [ - eq(a_car, b_car), - Zzz( - map_anyo, - relation, - a_cdr, - b_cdr, - null_type=null_type, - any_succeed=any_succeed, - first=False, - ), - ], + Zzz(goal, *cars, **kwargs), + Zzz( + map_anyo, + goal, + *cdrs, + null_type=null_type, + null_res=null_res, + _first=False, + _any_succeed=True if _any_succeed is not None else None, + **kwargs, ), - ], + ] + ] + + if _any_succeed is not None: + nullo_condition = _any_succeed or (_first and null_res) + conde_branches.append( + [eq(a_car, b_car) for a_car, b_car in zip(cars, cars[1:])] + + [ + Zzz( + map_anyo, + goal, + *cdrs, + null_type=null_type, + null_res=null_res, + _first=False, + _any_succeed=_any_succeed, + **kwargs, + ), + ] + ) + else: + nullo_condition = not _first or null_res + + if use_ground_ordering: + args_ord = tuple(var() for t in args) + ground_order_goal = ground_order_seqs(args, args_ord) + else: + args_ord = args + ground_order_goal = succeed + + return conde( + [nullo(*args, default_ConsNull=null_type) if nullo_condition else fail], + [ground_order_goal] + + [conso(car, cdr, arg) for car, cdr, arg in zip(cars, cdrs, args_ord)] + + [conde(*conde_branches)], ) @@ -90,15 +120,19 @@ def eq_length(u, v, default_ConsNull=list): return mapo(vararg_success, u, v, null_type=default_ConsNull) -def reduceo(relation, in_term, out_term, *args, **kwargs): - """Relate a term and the fixed-point of that term under a given relation. +def reduceo(goal, in_term, out_term, *args, **kwargs): + """Construct a goal that yields the fixed-point of another goal. + + It simply tries to order the implicit `conde` recursion branches so that they + produce the fixed-point value first. All goals that follow are the reductions + leading up to the fixed-point. - This includes the "identity" relation. + FYI: The results will include `eq(in_term, out_term)`. """ def reduceo_goal(s): - nonlocal in_term, out_term, relation, args, kwargs + nonlocal in_term, out_term, goal, args, kwargs in_term_rf, out_term_rf = reify((in_term, out_term), s) @@ -106,19 +140,19 @@ def reduceo_goal(s): term_rdcd = var() # Are we working "backward" and (potentially) "expanding" a graph - # (e.g. when the relation is a reduction rule)? + # (e.g. when the goal is a reduction rule)? is_expanding = isvar(in_term_rf) - # One application of the relation assigned to `term_rdcd` - single_apply_g = relation(in_term_rf, term_rdcd, *args, **kwargs) + # One application of the goal assigned to `term_rdcd` + single_apply_g = goal(in_term_rf, term_rdcd, *args, **kwargs) # Assign/equate (unify, really) the result of a single application to # the "output" term. single_res_g = eq(term_rdcd, out_term_rf) - # Recurse into applications of the relation (well, produce a goal that + # Recurse into applications of the goal (well, produce a goal that # will do that) - another_apply_g = reduceo(relation, term_rdcd, out_term_rf, *args, **kwargs) + another_apply_g = reduceo(goal, term_rdcd, out_term_rf, *args, **kwargs) # We want the fixed-point value to show up in the stream output # *first*, but that requires some checks. @@ -147,69 +181,52 @@ def reduceo_goal(s): def walko( goal, - graph_in, - graph_out, - rator_goal=None, + *terms, + pre_process_fn=None, null_type=etuple, map_rel=partial(map_anyo, null_res=True), ): - """Apply a binary relation between all nodes in two graphs. - - When `rator_goal` is used, the graphs are treated as term graphs, and the - multi-functions `rator`, `rands`, and `apply` are used to walk the graphs. - Otherwise, the graphs must be iterable according to `map_anyo`. + """Apply a goal between all nodes in a set of terms. Parameters ---------- goal: callable - A goal that is applied to all terms in the graph. - graph_in: object - The graph for which the left-hand side of a binary relation holds. - graph_out: object - The graph for which the right-hand side of a binary relation holds. - rator_goal: callable (default None) - A goal that is applied to the rators of a graph. When specified, - `goal` is only applied to rands and it must succeed along with the - rator goal in order to descend into sub-terms. + A goal that is applied to all terms and their sub-terms. + *terms: Sequence of objects + The terms to be walked. + pre_process_fn: callable (default None) + A goal with a signature of the form `(old_terms, new_terms)`, where + each argument is a list of corresponding terms. + This goal can be used to transform terms before walking them. null_type: type - The collection type used when it is not fully determined by the graph + The collection type used when it is not fully determined by the `terms` arguments. map_rel: callable - The map relation used to apply `goal` to a sub-graph. + The map goal used to apply `goal` to corresponding sub-terms. """ - def walko_goal(s): - - nonlocal goal, rator_goal, graph_in, graph_out, null_type, map_rel - - graph_in_rf, graph_out_rf = reify((graph_in, graph_out), s) - - rator_in, rands_in, rator_out, rands_out = var(), var(), var(), var() - - _walko = partial( - walko, goal, rator_goal=rator_goal, null_type=null_type, map_rel=map_rel - ) - - g = conde( - # TODO: Use `Zzz`, if needed. - [ - goal(graph_in_rf, graph_out_rf), - ], - [ - lall( - applyo(rator_in, rands_in, graph_in_rf), - applyo(rator_out, rands_out, graph_out_rf), - rator_goal(rator_in, rator_out), - map_rel(_walko, rands_in, rands_out, null_type=null_type), - ) - if rator_goal is not None - else map_rel(_walko, graph_in_rf, graph_out_rf, null_type=null_type), - ], - ) - - yield from g(s) + if pre_process_fn is not None: + terms_pp = tuple(var() for t in terms) + pre_process_goal = pre_process_fn(*(terms + terms_pp)) + else: + terms_pp = terms + pre_process_goal = succeed + + _walko = partial( + walko, + goal, + pre_process_fn=pre_process_fn, + null_type=null_type, + map_rel=map_rel, + ) - return walko_goal + return lall( + pre_process_goal, + conde( + [Zzz(goal, *terms_pp)], + [Zzz(map_rel, _walko, *terms_pp, null_type=null_type)], + ), + ) def term_walko( @@ -220,7 +237,7 @@ def term_walko( null_type=etuple, no_ident=False, format_step=None, - **kwargs + **kwargs, ): """Construct a goal for walking a term graph. @@ -233,13 +250,11 @@ def term_walko( should always fail! """ - def single_step(s, t): - u, v = var(), var() + def single_step(u, v): u_rator, u_rands = var(), var() v_rands = var() return lall( - ground_order((s, t), (u, v)), applyo(u_rator, u_rands, u), applyo(u_rator, v_rands, v), rator_goal(u_rator), @@ -250,13 +265,11 @@ def single_step(s, t): Zzz(rands_goal, u_rands, v_rands, u_rator, **kwargs), ) - def term_walko_step(s, t): + def term_walko_step(u, v): nonlocal rator_goal, rands_goal, null_type - u, v = var(), var() z, w = var(), var() return lall( - ground_order((s, t), (u, v)), format_step(u, w) if format_step is not None else eq(u, w), conde( [ diff --git a/kanren/term.py b/kanren/term.py index 6c6a094..7cc0e3e 100644 --- a/kanren/term.py +++ b/kanren/term.py @@ -1,16 +1,46 @@ -from collections.abc import Mapping, Sequence +from abc import ABCMeta +from collections.abc import Callable, Mapping +from operator import length_hint -from cons.core import ConsError, cons +from cons.core import ConsError, ProperSequence, cons +from etuples import apply from etuples import apply as term from etuples import rands as arguments from etuples import rator as operator +from multipledispatch import dispatch from unification.core import _reify, _unify, construction_sentinel, reify from unification.variable import isvar -from .core import eq, lall +from .core import eq, lall, shallow_ground_order_key from .goals import conso +class TermMetaType(ABCMeta): + """A meta type that can be used to check for `operator`/`arguments` support.""" + + def __instancecheck__(self, o): + o_type = type(o) + if ( + isinstance(o, ProperSequence) + or any( + issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object + ) + ) and length_hint(o, 1) > 0: + return True + return False + + def __subclasscheck__(self, o_type): + if issubclass(o_type, ProperSequence) or any( + issubclass(o_type, k) for k in operator.funcs.keys() if k[0] != object + ): + return True + return False + + +class TermType(metaclass=TermMetaType): + pass + + def applyo(o_rator, o_rands, obj): """Construct a goal that relates an object to the application of its (ope)rator to its (ope)rands. @@ -55,14 +85,19 @@ def applyo_goal(S): return applyo_goal -@term.register(object, Sequence) -def term_Sequence(rator, rands): +@dispatch(object, ProperSequence) +def term(rator, rands): # Overwrite the default `apply` dispatch function and make it preserve # types res = cons(rator, rands) return res +@term.register(Callable, ProperSequence) +def term_ExpressionTuple(rator, rands): + return apply(rator, rands) + + def unifiable_with_term(cls): _reify.add((cls, Mapping), reify_term) _unify.add((cls, cls, Mapping), unify_term) @@ -84,3 +119,13 @@ def unify_term(u, v, s): if s is not False: s = yield _unify(u_args, v_args, s) yield s + + +@shallow_ground_order_key.register(Mapping, TermType) +def shallow_ground_order_key_TermType(S, x): + if length_hint(x, 1) > 0: + return shallow_ground_order_key.dispatch(type(S), object)( + S, cons(operator(x), arguments(x)) + ) + else: + return shallow_ground_order_key.dispatch(type(S), object)(S, x) diff --git a/requirements.txt b/requirements.txt index 87db56e..995ffd8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pydocstyle>=3.0.0 pytest>=5.0.0 pytest-cov>=2.6.1 pytest-html>=1.20.0 +pytest-timeout pylint>=2.3.1 black>=19.3b0; platform.python_implementation!='PyPy' diff-cover diff --git a/setup.py b/setup.py index c7a77c6..06ee62d 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ packages=["kanren"], install_requires=[ "toolz", - "cons >= 0.4.0", + "cons >= 0.4.2", "multipledispatch", "etuples >= 0.3.1", "logical-unification >= 0.4.1", diff --git a/tests/test_assoccomm.py b/tests/test_assoccomm.py index 3795701..8b2c39b 100644 --- a/tests/test_assoccomm.py +++ b/tests/test_assoccomm.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from itertools import chain import pytest from cons import cons @@ -7,7 +7,7 @@ from kanren.assoccomm import ( assoc_args, - assoc_flatten, + assoc_flatteno, associative, commutative, eq_assoc, @@ -15,63 +15,11 @@ eq_assoccomm, eq_comm, flatten_assoc_args, + partitions, ) from kanren.core import run from kanren.facts import fact -from kanren.term import arguments, operator, term - - -class Node(object): - def __init__(self, op, args): - self.op = op - self.args = args - - def __eq__(self, other): - return ( - type(self) == type(other) - and self.op == other.op - and self.args == other.args - ) - - def __hash__(self): - return hash((type(self), self.op, self.args)) - - def __str__(self): - return "%s(%s)" % (self.op.name, ", ".join(map(str, self.args))) - - __repr__ = __str__ - - -class Operator(object): - def __init__(self, name): - self.name = name - - -Add = Operator("add") -Mul = Operator("mul") - - -def add(*args): - return Node(Add, args) - - -def mul(*args): - return Node(Mul, args) - - -@term.register(Operator, Sequence) -def term_Operator(op, args): - return Node(op, args) - - -@arguments.register(Node) -def arguments_Node(n): - return n.args - - -@operator.register(Node) -def operator_Node(n): - return n.op +from tests.utils import Add def results(g, s=None): @@ -81,8 +29,6 @@ def results(g, s=None): def test_eq_comm(): - x, y, z = var(), var(), var() - commutative.facts.clear() commutative.index.clear() @@ -90,12 +36,14 @@ def test_eq_comm(): fact(commutative, comm_op) + x, y, z = var(), var(), var() + assert run(0, True, eq_comm(1, 1)) == (True,) assert run(0, True, eq_comm((comm_op, 1, 2, 3), (comm_op, 1, 2, 3))) == (True,) assert run(0, True, eq_comm((comm_op, 3, 2, 1), (comm_op, 1, 2, 3))) == (True,) - assert run(0, y, eq_comm((comm_op, 3, y, 1), (comm_op, 1, 2, 3))) == (2,) - assert run(0, (x, y), eq_comm((comm_op, x, y, 1), (comm_op, 1, 2, 3))) == ( + assert run(0, y, eq_comm((comm_op, 1, 2, 3), (comm_op, 3, y, 1))) == (2,) + assert run(0, (x, y), eq_comm((comm_op, 1, 2, 3), (comm_op, x, y, 1))) == ( (2, 3), (3, 2), ) @@ -111,20 +59,6 @@ def test_eq_comm(): assert not run(0, True, eq_comm((comm_op, 1, 2, 1), (comm_op, 1, 2, 3))) assert not run(0, True, eq_comm(("op", 1, 2, 3), (comm_op, 1, 2, 3))) - # Test for variable args - res = run(4, (x, y), eq_comm(x, y)) - exp_res_form = ( - (etuple(comm_op, x, y), etuple(comm_op, y, x)), - (x, y), - (etuple(etuple(comm_op, x, y)), etuple(etuple(comm_op, y, x))), - (etuple(comm_op, x, y, z), etuple(comm_op, x, z, y)), - ) - - for a, b in zip(res, exp_res_form): - s = unify(a, b) - assert s is not False - assert all(isvar(i) for i in reify((x, y, z), s)) - # Make sure it can unify single elements assert (3,) == run(0, x, eq_comm((comm_op, 1, 2, 3), (comm_op, 2, x, 1))) @@ -141,9 +75,9 @@ def test_eq_comm(): assert expected_res == set( run(0, (x, y, z), eq_comm((comm_op, 1, 2, 3), (comm_op, x, y, z))) ) - assert expected_res == set( - run(0, (x, y, z), eq_comm((comm_op, x, y, z), (comm_op, 1, 2, 3))) - ) + # assert expected_res == set( + # run(0, (x, y, z), eq_comm((comm_op, x, y, z), (comm_op, 1, 2, 3))) + # ) assert expected_res == set( run( 0, @@ -152,37 +86,80 @@ def test_eq_comm(): ) ) - e1 = (comm_op, (comm_op, 1, x), y) - e2 = (comm_op, 2, (comm_op, 3, 1)) + e1 = (comm_op, 2, (comm_op, 3, 1)) + e2 = (comm_op, (comm_op, 1, x), y) assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) e1 = ((comm_op, 3, 1),) e2 = ((comm_op, 1, x),) - assert run(0, x, eq_comm(e1, e2)) == (3,) e1 = (2, (comm_op, 3, 1)) e2 = (y, (comm_op, 1, x)) - assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) - e1 = (comm_op, (comm_op, 1, x), y) - e2 = (comm_op, 2, (comm_op, 3, 1)) - assert run(0, (x, y), eq_comm(e1, e2)) == ((3, 2),) +def test_eq_comm_all_variations(): + commutative.facts.clear() + commutative.index.clear() + + comm_op = "comm_op" + + fact(commutative, comm_op) + + expected_res = { + (comm_op, 1, (comm_op, 2, (comm_op, 3, 4))), + (comm_op, 1, (comm_op, 2, (comm_op, 4, 3))), + (comm_op, 1, (comm_op, (comm_op, 3, 4), 2)), + (comm_op, 1, (comm_op, (comm_op, 4, 3), 2)), + (comm_op, (comm_op, 2, (comm_op, 3, 4)), 1), + (comm_op, (comm_op, 2, (comm_op, 4, 3)), 1), + (comm_op, (comm_op, (comm_op, 3, 4), 2), 1), + (comm_op, (comm_op, (comm_op, 4, 3), 2), 1), + } + + x = var() + + for s in expected_res: + res = run(0, x, eq_comm(s, x)) + assert sorted(res, key=str) == sorted(expected_res, key=str) + + +def test_eq_comm_unground(): + commutative.facts.clear() + commutative.index.clear() + + comm_op = "comm_op" + + fact(commutative, comm_op) + + x, y, z = var(), var(), var() + # Test for variable args + res = run(4, (x, y), eq_comm(x, y)) + exp_res_form = ( + (etuple(comm_op, x, y), etuple(comm_op, y, x)), + (x, y), + (etuple(etuple(comm_op, x, y)), etuple(etuple(comm_op, y, x))), + (etuple(comm_op, x, y, z), etuple(comm_op, x, z, y)), + ) + + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert s is not False + assert all(isvar(i) for i in reify((x, y, z), s)) @pytest.mark.xfail(reason="`applyo`/`buildo` needs to be a constraint.", strict=True) def test_eq_comm_object(): - x = var("x") + x = var() fact(commutative, Add) fact(associative, Add) - assert run(0, x, eq_comm(add(1, 2, 3), add(3, 1, x))) == (2,) - assert set(run(0, x, eq_comm(add(1, 2), x))) == set((add(1, 2), add(2, 1))) - assert set(run(0, x, eq_assoccomm(add(1, 2, 3), add(1, x)))) == set( - (add(2, 3), add(3, 2)) + assert run(0, x, eq_comm(Add(1, 2, 3), Add(3, 1, x))) == (2,) + assert set(run(0, x, eq_comm(Add(1, 2), x))) == set((Add(1, 2), Add(2, 1))) + assert set(run(0, x, eq_assoccomm(Add(1, 2, 3), Add(1, x)))) == set( + (Add(2, 3), Add(3, 2)) ) @@ -198,39 +175,145 @@ def op_pred(x): res = list( flatten_assoc_args( - op_pred, [[1, 2, op], 3, [op, 4, [op, [op]]], [op, 5], 6, op, 7] + op_pred, + [[1, 2, op], 3, [op, 4, [op, [op]]], [op, 5], 6, op, 7], + shallow=False, ) ) exp_res = [[1, 2, op], 3, 4, [op], 5, 6, op, 7] assert res == exp_res - -def test_assoc_args(): - op = "add" - - def op_pred(x): - return x == op - - assert tuple(assoc_args(op, (1, 2, 3), 2)) == ( - ((op, 1, 2), 3), - (1, (op, 2, 3)), - ) - assert tuple(assoc_args(op, [1, 2, 3], 2)) == ( - [[op, 1, 2], 3], - [1, [op, 2, 3]], + exa_col = (1, 2, ("b", 3, ("a", 4, 5)), ("c", 6, 7), ("a", ("a", 8, 9), 10)) + assert list(flatten_assoc_args(lambda x: x == "a", exa_col, shallow=False)) == [ + 1, + 2, + ("b", 3, ("a", 4, 5)), + ("c", 6, 7), + 8, + 9, + 10, + ] + + assert list(flatten_assoc_args(lambda x: x == "a", exa_col, shallow=True)) == [ + 1, + 2, + ("b", 3, ("a", 4, 5)), + ("c", 6, 7), + ("a", 8, 9), + 10, + ] + + +def test_partitions(): + + assert list(partitions(("a",), 2, 2)) == [] + assert list(partitions(("a", "b"), 2, 2)) == [] + assert list(partitions(("a", "b"), 2, 1)) == [(("a",), ("b",))] + assert list(partitions(("a", "b"), 2, 1, part_fn=lambda x: ("op",) + x)) == [ + (("op", "a"), ("op", "b")) + ] + + exa_col = tuple("abcdefg") + + expected_res = [ + (("a", "b"), ("c", "d", "e", "f", "g")), + (("a", "b", "c"), ("d", "e", "f", "g")), + (("a", "b", "c", "d"), ("e", "f", "g")), + (("a", "b", "c", "d", "e"), ("f", "g")), + ] + + assert list(partitions(exa_col, 2, 2)) == expected_res + + expected_res = [ + (("a", "b"), ("c", "d"), ("e", "f", "g")), + (("a", "b"), ("c", "d", "e"), ("f", "g")), + (("a", "b", "c"), ("d", "e"), ("f", "g")), + ] + + assert list(partitions(exa_col, 3, 2)) == expected_res + + expected_res = sorted( + chain.from_iterable( + [partitions(exa_col, i, 2) for i in range(2, len(exa_col) + 1)] + ) ) - assert tuple(assoc_args(op, (1, 2, 3), 1)) == ( - ((op, 1), 2, 3), - (1, (op, 2), 3), - (1, 2, (op, 3)), + assert sorted(partitions(exa_col, None, 2)) == expected_res + + res = list( + partitions( + tuple(range(1, 5)), + None, + 1, + part_fn=lambda x: x[0] if len(x) == 1 else ("op",) + x, + ) ) - assert tuple(assoc_args(op, (1, 2, 3), 3)) == ((1, 2, 3),) + assert res == [ + (1, ("op", 2, 3, 4)), + (1, 2, ("op", 3, 4)), + (1, 2, 3, 4), + (1, ("op", 2, 3), 4), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2), 3, 4), + (("op", 1, 2, 3), 4), + ] - f_rands = flatten_assoc_args(op_pred, (1, (op, 2, 3))) - assert tuple(assoc_args(op, f_rands, 2, ctor=tuple)) == ( - ((op, 1, 2), 3), - (1, (op, 2, 3)), + res = list( + partitions( + tuple(range(1, 5)), + 2, + 1, + part_fn=lambda x: x[0] if len(x) == 1 else ("op",) + x, + ) ) + assert res == [ + (1, ("op", 2, 3, 4)), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2, 3), 4), + ] + + +def test_assoc_args(): + + res = list(assoc_args("op", tuple(range(1, 5)), None)) + assert res == [ + (1, ("op", 2, 3, 4)), + (1, 2, ("op", 3, 4)), + (1, 2, 3, 4), + (1, ("op", 2, 3), 4), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2), 3, 4), + (("op", 1, 2, 3), 4), + ] + + res = list(assoc_args("op", tuple(range(1, 5)), None, ctor=list)) + assert res == [ + [1, ["op", 2, 3, 4]], + [1, 2, ["op", 3, 4]], + [1, 2, 3, 4], + [1, ["op", 2, 3], 4], + [["op", 1, 2], ["op", 3, 4]], + [["op", 1, 2], 3, 4], + [["op", 1, 2, 3], 4], + ] + + res = list(assoc_args("op", tuple(range(1, 5)), 2)) + assert res == [ + (1, ("op", 2, 3, 4)), + (("op", 1, 2), ("op", 3, 4)), + (("op", 1, 2, 3), 4), + ] + + res = list(assoc_args("op", (1, 2), 1)) + assert res == [(1, 2)] + + res = list(assoc_args("op", (1, 2, 3), 4)) + assert res == [(1, 2, 3)] + + res = list(assoc_args("op", (1, 2, 3), 3)) + assert res == [(1, 2, 3)] + + res = list(assoc_args("op", [1, 2, 3], 3, ctor=tuple)) + assert res == [(1, 2, 3)] def test_eq_assoc_args(): @@ -273,18 +356,18 @@ def test_eq_assoc_args(): assert run(0, True, eq_assoc_args(assoc_op, (1, 1), ("other_op", 1, 1))) == () assert run(0, x, eq_assoc_args(assoc_op, (1, 2, 3), x, n=2)) == ( - ((assoc_op, 1, 2), 3), (1, (assoc_op, 2, 3)), + ((assoc_op, 1, 2), 3), ) assert run(0, x, eq_assoc_args(assoc_op, x, (1, 2, 3), n=2)) == ( - ((assoc_op, 1, 2), 3), (1, (assoc_op, 2, 3)), + ((assoc_op, 1, 2), 3), ) assert run(0, x, eq_assoc_args(assoc_op, (1, 2, 3), x)) == ( - ((assoc_op, 1, 2), 3), (1, (assoc_op, 2, 3)), (1, 2, 3), + ((assoc_op, 1, 2), 3), ) assert () not in run(0, x, eq_assoc_args(assoc_op, (), x, no_ident=True)) @@ -324,11 +407,11 @@ def test_eq_assoc_args(): def test_eq_assoc(): - assoc_op = "assoc_op" - associative.index.clear() associative.facts.clear() + assoc_op = "assoc_op" + fact(associative, assoc_op) assert run(0, True, eq_assoc(1, 1)) == (True,) @@ -343,61 +426,118 @@ def test_eq_assoc(): o = "op" assert not run(0, True, eq_assoc((o, 1, 2, 3), (o, (o, 1, 2), 3))) - x = var() + x, y = var(), var() + res = run(0, x, eq_assoc((assoc_op, 1, 2, 3), x, n=2)) assert res == ( - (assoc_op, (assoc_op, 1, 2), 3), - (assoc_op, 1, 2, 3), (assoc_op, 1, (assoc_op, 2, 3)), + (assoc_op, 1, 2, 3), + (assoc_op, (assoc_op, 1, 2), 3), ) - res = run(0, x, eq_assoc(x, (assoc_op, 1, 2, 3), n=2)) - assert res == ( - (assoc_op, (assoc_op, 1, 2), 3), - (assoc_op, 1, 2, 3), - (assoc_op, 1, (assoc_op, 2, 3)), + # Make sure it works with `cons` + res = run(0, (x, y), eq_assoc((assoc_op, 1, 2, 3), cons(x, y))) + assert sorted(res, key=str) == sorted( + [ + (assoc_op, (1, (assoc_op, 2, 3))), + (assoc_op, (1, 2, 3)), + (assoc_op, ((assoc_op, 1, 2), 3)), + ], + key=str, ) - y, z = var(), var() + # XXX: Don't use a predicate that can never succeed, e.g. + # associative_2 = Relation("associative_2") + # run(1, (x, y), eq_assoc(cons(x, y), (x, z), op_predicate=associative_2)) + + # Nested expressions should work now + expr1 = (assoc_op, (assoc_op, 1, 2), 3, 4, 5, 6) + expr2 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) + assert run(0, x, eq_assoc(expr1, expr2, n=2)) == ((assoc_op, 3, 4),) + + +@pytest.mark.xfail(strict=False) +def test_eq_assoc_cons(): + associative.index.clear() + associative.facts.clear() + + assoc_op = "assoc_op" + + fact(associative, assoc_op) + + x, y, z = var(), var(), var() + + res = run(1, (x, y), eq_assoc(cons(x, y), (x, z, 2, 3))) + assert res == ((assoc_op, (z, (assoc_op, 2, 3))),) + + +@pytest.mark.xfail(strict=False) +def test_eq_assoc_all_variations(): + + associative.index.clear() + associative.facts.clear() + + assoc_op = "assoc_op" + + fact(associative, assoc_op) + + x = var() + # Normalized, our results are left-associative, i.e. + # (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4) == (assoc_op, 1, 2, 3, 4) + expected_res = { + (assoc_op, (assoc_op, (assoc_op, 1, 2), 3), 4), # Missing + (assoc_op, (assoc_op, 1, (assoc_op, 2, 3)), 4), # Missing + (assoc_op, (assoc_op, 1, 2), (assoc_op, 3, 4)), + (assoc_op, (assoc_op, 1, 2), 3, 4), + (assoc_op, (assoc_op, 1, 2, 3), 4), + (assoc_op, 1, (assoc_op, (assoc_op, 2, 3), 4)), # Missing + (assoc_op, 1, (assoc_op, 2, (assoc_op, 3, 4))), # Missing + (assoc_op, 1, (assoc_op, 2, 3), 4), + (assoc_op, 1, (assoc_op, 2, 3, 4)), + (assoc_op, 1, 2, (assoc_op, 3, 4)), + (assoc_op, 1, 2, 3, 4), + } + + for s in expected_res: + res = run(0, x, eq_assoc(s, x)) + assert sorted(res, key=str) == sorted(expected_res, key=str) + + +def test_eq_assoc_unground(): + + associative.index.clear() + associative.facts.clear() + + assoc_op = "assoc_op" + + fact(associative, assoc_op) + + x, y = var(), var() + xx, yy, zz = var(), var(), var() # Check results when both arguments are variables res = run(3, (x, y), eq_assoc(x, y)) exp_res_form = ( - (etuple(assoc_op, x, y, z), etuple(assoc_op, etuple(assoc_op, x, y), z)), - (x, y), + (etuple(assoc_op, xx, yy, zz), etuple(assoc_op, xx, etuple(assoc_op, yy, zz))), + (xx, yy), ( - etuple(etuple(assoc_op, x, y, z)), - etuple(etuple(assoc_op, etuple(assoc_op, x, y), z)), + etuple(etuple(assoc_op, xx, yy, zz)), + etuple(etuple(assoc_op, xx, etuple(assoc_op, yy, zz))), ), ) for a, b in zip(res, exp_res_form): s = unify(a, b) assert s is not False, (a, b) - assert all(isvar(i) for i in reify((x, y, z), s)) - - # Make sure it works with `cons` - res = run(0, (x, y), eq_assoc(cons(x, y), (assoc_op, 1, 2, 3))) - assert res == ( - (assoc_op, ((assoc_op, 1, 2), 3)), - (assoc_op, (1, 2, 3)), - (assoc_op, (1, (assoc_op, 2, 3))), - ) - - res = run(1, (x, y), eq_assoc(cons(x, y), (x, z, 2, 3))) - assert res == ((assoc_op, ((assoc_op, z, 2), 3)),) - - # Don't use a predicate that can never succeed, e.g. - # associative_2 = Relation("associative_2") - # run(1, (x, y), eq_assoc(cons(x, y), (x, z), op_predicate=associative_2)) + assert all(isvar(i) for i in reify((xx, yy, zz), s)) - # Nested expressions should work now - expr1 = (assoc_op, 1, 2, (assoc_op, x, 5, 6)) - expr2 = (assoc_op, (assoc_op, 1, 2), 3, 4, 5, 6) - assert run(0, x, eq_assoc(expr1, expr2, n=2)) == ((assoc_op, 3, 4),) +def test_assoc_flatteno(): -def test_assoc_flatten(): + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() add = "add" mul = "mul" @@ -406,12 +546,13 @@ def test_assoc_flatten(): fact(associative, add) fact(commutative, mul) fact(associative, mul) + fact(associative, Add) assert ( run( 0, True, - assoc_flatten( + assoc_flatteno( (mul, 1, (add, 2, 3), (mul, 4, 5)), (mul, 1, (add, 2, 3), 4, 5) ), ) @@ -419,45 +560,120 @@ def test_assoc_flatten(): ) x = var() - assert ( - run( - 0, - x, - assoc_flatten((mul, 1, (add, 2, 3), (mul, 4, 5)), x), - ) - == ((mul, 1, (add, 2, 3), 4, 5),) + assert run(0, x, assoc_flatteno((mul, 1, (add, 2, 3), (mul, 4, 5)), x)) == ( + (mul, 1, (add, 2, 3), 4, 5), ) assert ( run( 0, True, - assoc_flatten( + assoc_flatteno( ("op", 1, (add, 2, 3), (mul, 4, 5)), ("op", 1, (add, 2, 3), (mul, 4, 5)) ), ) == (True,) ) - assert run(0, x, assoc_flatten(("op", 1, (add, 2, 3), (mul, 4, 5)), x)) == ( + assert run(0, x, assoc_flatteno(("op", 1, (add, 2, 3), (mul, 4, 5)), x)) == ( ("op", 1, (add, 2, 3), (mul, 4, 5)), ) + assert run( + 0, True, assoc_flatteno((add, 1, (add, 2, 3), (mul, 4, 5)), x, no_ident=True) + ) == (True,) + + assert ( + run( + 0, + True, + assoc_flatteno((add, 1, (mul, 2, 3), (mul, 4, 5)), x, no_ident=True), + ) + == () + ) + + assert run(0, True, assoc_flatteno((add, 1, 2, 3), x)) == (True,) + assert run(0, True, assoc_flatteno(Add(1, 2, 3), x)) == (True,) + assert run(0, True, assoc_flatteno((add, 1, 2, 3), x, no_ident=True)) == () + + res = run(0, x, assoc_flatteno(x, (add, 1, 2, 3), no_ident=True)) + assert sorted(res, key=str) == sorted( + [(add,) + r for r in assoc_args(add, (1, 2, 3)) if r != (add, 1, 2, 3)], key=str + ) + + res = run(0, x, assoc_flatteno(x, (add, 1, 2, 3), no_ident=True)) + assert sorted(res, key=str) == sorted( + [(add,) + r for r in assoc_args(add, (1, 2, 3)) if r != (add, 1, 2, 3)], key=str + ) + + +def test_assoc_flatteno_unground(): + + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() + + add = "add" + mul = "mul" + + fact(commutative, add) + fact(associative, add) + fact(commutative, mul) + fact(associative, mul) -def test_eq_assoccomm(): x, y = var(), var() - ac = "commassoc_op" + xx, yy, zz = var(), var(), var() + op_lv = var() + exp_res_form = ( + (xx, yy), + (etuple(op_lv, xx, yy), etuple(op_lv, xx, yy)), + (etuple(op_lv, xx, yy), etuple(op_lv, xx, yy)), + (etuple(op_lv, xx, etuple(op_lv, yy, zz)), etuple(op_lv, xx, yy, zz)), + ) + + res = run(4, (x, y), assoc_flatteno(x, y)) + for a, b in zip(res, exp_res_form): + s = unify(a, b) + assert s is not False, (a, b) + assert op_lv not in s or (s[op_lv],) in associative.facts + assert all(isvar(i) for i in reify((xx, yy, zz), s)) + + ww = var() + exp_res_form = ( + (etuple(op_lv, xx, etuple(op_lv, yy, zz)), etuple(op_lv, xx, yy, zz)), + (etuple(op_lv, xx, etuple(op_lv, yy, zz)), etuple(op_lv, xx, yy, zz)), + (etuple(op_lv, xx, etuple(op_lv, yy, zz, ww)), etuple(op_lv, xx, yy, zz, ww)), + ) + res = run(3, (x, y), assoc_flatteno(x, y, no_ident=True)) + + for a, b in zip(res, exp_res_form): + assert a[0] != a[1] + s = unify(a, b) + assert s is not False, (a, b) + assert op_lv not in s or (s[op_lv],) in associative.facts + assert all(isvar(i) for i in reify((xx, yy, zz, ww), s)) + + +def test_eq_assoccomm(): + + associative.index.clear() + associative.facts.clear() commutative.index.clear() commutative.facts.clear() + ac = "commassoc_op" + fact(commutative, ac) fact(associative, ac) + x, y = var(), var() + assert run(0, True, eq_assoccomm(1, 1)) == (True,) assert run(0, True, eq_assoccomm((1,), (1,))) == (True,) - assert run(0, True, eq_assoccomm(x, (1,))) == (True,) + # assert run(0, True, eq_assoccomm(x, (1,))) == (True,) assert run(0, True, eq_assoccomm((1,), x)) == (True,) # Assoc only @@ -473,7 +689,38 @@ def test_eq_assoccomm(): True, ) - exp_res = set( + assert run(0, (x, y), eq_assoccomm((ac, 2, (ac, 3, 1)), (ac, (ac, 1, x), y))) == ( + (2, 3), + (3, 2), + ) + # assert run(0, (x, y), eq_assoccomm((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) == ( + # (2, 3), + # (3, 2), + # ) + + assert run(0, True, eq_assoccomm((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) == (True,) + assert run(0, True, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) == (True,) + assert run(0, True, eq_assoccomm((ac, 1, 1), ("other_op", 1, 1))) == () + + assert run(0, x, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, x, 3))) == (2,) + + +def test_eq_assoccomm_all_variations(): + + associative.index.clear() + associative.facts.clear() + commutative.index.clear() + commutative.facts.clear() + + ac = "commassoc_op" + + fact(commutative, ac) + fact(associative, ac) + + x = var() + + # TODO: Use four arguments to see real associative variation. + expected_res = set( ( (ac, 1, 3, 2), (ac, 1, 2, 3), @@ -495,22 +742,28 @@ def test_eq_assoccomm(): (ac, (ac, 2, 1), 3), ) ) - assert set(run(0, x, eq_assoccomm((ac, 1, (ac, 2, 3)), x))) == exp_res - assert set(run(0, x, eq_assoccomm((ac, 1, 3, 2), x))) == exp_res - assert set(run(0, x, eq_assoccomm((ac, 2, (ac, 3, 1)), x))) == exp_res - # LHS variations - assert set(run(0, x, eq_assoccomm(x, (ac, 1, (ac, 2, 3))))) == exp_res - assert run(0, (x, y), eq_assoccomm((ac, (ac, 1, x), y), (ac, 2, (ac, 3, 1)))) == ( - (2, 3), - (3, 2), - ) + for s in expected_res: + res = run(0, x, eq_assoccomm(s, x)) + # TODO FIXME: Avoid the extra identity result + res = list(res) + res.remove(s) + assert sorted(res, key=str) == sorted(expected_res, key=str) - assert run(0, True, eq_assoccomm((ac, (ac, 1, 2), 3), (ac, 1, 2, 3))) == (True,) - assert run(0, True, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, 2, 3))) == (True,) - assert run(0, True, eq_assoccomm((ac, 1, 1), ("other_op", 1, 1))) == () - assert run(0, x, eq_assoccomm((ac, 3, (ac, 1, 2)), (ac, 1, x, 3))) == (2,) +def test_eq_assoccomm_unground(): + + associative.index.clear() + associative.facts.clear() + commutative.index.clear() + commutative.facts.clear() + + ac = "commassoc_op" + + fact(commutative, ac) + fact(associative, ac) + + x, y = var(), var() # Both arguments unground op_lv = var() @@ -520,36 +773,33 @@ def test_eq_assoccomm(): (etuple(op_lv, x, y), etuple(op_lv, y, x)), (y, y), ( - etuple(etuple(op_lv, x, y)), - etuple(etuple(op_lv, y, x)), - ), - ( - etuple(op_lv, x, y, z), - etuple(op_lv, etuple(op_lv, x, y), z), + etuple(op_lv, x, y), + etuple(op_lv, y, x), ), + (etuple(op_lv, x, etuple(op_lv, y, z)), etuple(op_lv, x, etuple(op_lv, z, y))), ) for a, b in zip(res, exp_res_form): s = unify(a, b) + assert s is not False, (a, b) assert ( op_lv not in s or (s[op_lv],) in associative.facts or (s[op_lv],) in commutative.facts ) - assert s is not False, (a, b) assert all(isvar(i) for i in reify((x, y, z), s)) -def test_assoccomm_algebra(): - - add = "add" - mul = "mul" +def test_eq_assoccomm_algebra(): commutative.index.clear() commutative.facts.clear() associative.index.clear() associative.facts.clear() + add = "add" + mul = "mul" + fact(commutative, add) fact(associative, add) fact(commutative, mul) @@ -557,13 +807,13 @@ def test_assoccomm_algebra(): x, y = var(), var() - pattern = (mul, (add, 1, x), y) # (1 + x) * y - expr = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) + pattern = (mul, 2, (add, 3, 1)) # 2 * (3 + 1) + expr = (mul, (add, 1, x), y) # (1 + x) * y assert run(0, (x, y), eq_assoccomm(pattern, expr)) == ((3, 2),) -def test_assoccomm_objects(): +def test_eq_assoccomm_objects(): commutative.index.clear() commutative.facts.clear() @@ -575,6 +825,47 @@ def test_assoccomm_objects(): x = var() - assert run(0, True, eq_assoccomm(add(1, 2, 3), add(3, 1, 2))) == (True,) - assert run(0, x, eq_assoccomm(add(1, 2, 3), add(1, 2, x))) == (3,) - assert run(0, x, eq_assoccomm(add(1, 2, 3), add(x, 2, 1))) == (3,) + assert run(0, True, eq_assoccomm(Add(1, 2, 3), Add(3, 1, 2))) == (True,) + # FYI: If `Node` is made `unifiable_with_term` (along with `term`, + # `operator`, and `arguments` implementations), you'll get duplicate + # results in the following test (i.e. `(3, 3)`). + assert run(0, x, eq_assoccomm(Add(1, 2, 3), Add(1, 2, x))) == (3,) + assert run(0, x, eq_assoccomm(Add(1, 2, 3), Add(x, 2, 1))) == (3,) + + +@pytest.mark.xfail(strict=False) +@pytest.mark.timeout(5) +def test_eq_assoccomm_scaling(): + + commutative.index.clear() + commutative.facts.clear() + associative.index.clear() + associative.facts.clear() + + add = "add" + mul = "mul" + + fact(commutative, add) + fact(associative, add) + fact(commutative, mul) + fact(associative, mul) + + import random + + from tests.utils import generate_term + + random.seed(2343243) + + test_graph = generate_term((add, mul), range(4), 5) + + # Make a low-depth term inequal (e.g. inequal at base): + # Change the root operator + test_graph_2 = list(test_graph) + test_graph_2[0] = add if test_graph_2[0] == mul else mul + + assert test_graph != test_graph_2 + assert test_graph[1:] == test_graph_2[1:] + + assert run(0, True, eq_assoccomm(test_graph, test_graph_2)) == () + + # TODO: Make a high-depth term inequal diff --git a/tests/test_core.py b/tests/test_core.py index b1a761e..eefbd52 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -10,6 +10,7 @@ eq, fail, ground_order, + ground_order_seqs, ifa, lall, lany, @@ -18,6 +19,7 @@ ldisj, ldisj_seq, run, + shallow_ground_order_key, succeed, ) @@ -248,12 +250,53 @@ def test_ifa(): def test_ground_order(): x, y, z = var(), var(), var() + + assert shallow_ground_order_key({}, x) > shallow_ground_order_key({}, (x, y)) + assert shallow_ground_order_key({}, cons(x, y)) > shallow_ground_order_key( + {}, (x, y) + ) + assert shallow_ground_order_key({}, cons(1, 2)) < shallow_ground_order_key( + {}, (1, 2, 3, 4, y) + ) + assert shallow_ground_order_key({}, cons(1, 2)) == shallow_ground_order_key( + {}, (1, 2, 3, 4) + ) + assert shallow_ground_order_key({}, (x, y)) == shallow_ground_order_key( + {}, (x, y, z) + ) + assert run(0, x, ground_order((y, [1, z], 1), x)) == ([1, [1, z], y],) + a, b, c = var(), var(), var() assert run(0, (a, b, c), ground_order((y, [1, z], 1), (a, b, c))) == ( (1, [1, z], y), ) + res = run(0, z, ground_order([cons(x, y), (x, y)], z)) assert res == ([(x, y), cons(x, y)],) res = run(0, z, ground_order([(x, y), cons(x, y)], z)) assert res == ([(x, y), cons(x, y)],) + + +def test_ground_order_seq(): + + x, y, z = var(), var(), var() + a, b = var(), var() + res = run(0, (x, y), ground_order_seqs([a, (b, 2)], [x, y])) + assert res == ((a, (b, 2)),) + res = run(0, (x, y), ground_order_seqs([(a,), (b, 2)], [x, y])) + assert res == (((a,), (b, 2)),) + res = run(0, (x, y), ground_order_seqs([(a, b), (b, 2)], [x, y])) + assert res == (((b, a), (2, b)),) + res = run(0, (x, y), ground_order_seqs([(b, 2), (a, b)], [x, y])) + assert res == (((2, b), (b, a)),) + res = run(0, (x, y, z), ground_order_seqs([(b, 2), (a, b), (0, 1)], [x, y, z])) + assert res == (((2, b), (b, a), (1, 0)),) + res = run(0, (x, y), ground_order_seqs([(), ()], [x, y])) + assert res == (((), ()),) + res = run(0, (x, y), ground_order_seqs([(a, (1, b)), (b, 2)], [x, y])) + assert res == ((((1, b), a), (2, b)),) + res = run(0, (x, y), ground_order_seqs(["abc", "def"], [x, y])) + assert res == (("abc", "def"),) + res = run(0, (x, y), ground_order_seqs([[1, 2], (1, 2)], [x, y])) + assert res == (([1, 2], (1, 2)),) diff --git a/tests/test_graph.py b/tests/test_graph.py index 965240d..621b51f 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -11,7 +11,9 @@ from kanren import conde, eq, lall, run from kanren.constraints import isinstanceo +from kanren.core import Zzz from kanren.graph import eq_length, map_anyo, mapo, reduceo, walko +from kanren.term import applyo class OrderedFunction(object): @@ -71,11 +73,23 @@ def single_math_reduceo(expanded_term, reduced_term): math_reduceo = partial(reduceo, single_math_reduceo) -term_walko = partial( + +def walko_term_map_rel(walko_goal, x, y, **kwargs): + rator_in, rator_out = var(), var() + rands_in, rands_out = var(), var() + + return lall( + applyo(rator_in, rands_in, x), + applyo(rator_out, rands_out, y), + eq(rator_in, rator_out), + Zzz(map_anyo, walko_goal, rands_in, rands_out, null_res=False, **kwargs), + ) + + +walko_term = partial( walko, - rator_goal=eq, null_type=ExpressionTuple, - map_rel=partial(map_anyo, null_res=False), + map_rel=walko_term_map_rel, ) @@ -176,7 +190,7 @@ def test_map_anyo_types(): def test_map_anyo_misc(): - q_lv = var("q") + q_lv = var() res = run(0, q_lv, map_anyo(eq, [1, 2, 3], [1, 2, 3])) # TODO: Remove duplicate results @@ -201,13 +215,14 @@ def one_to_threeo(x, y): test_res = run(4, q_lv, map_anyo(math_reduceo, [1, etuple(add, 2, 2)], q_lv)) assert test_res == ([1, etuple(mul, 2, 2)],) - test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, var("z"))) + z = var() + test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, z)) assert all(isinstance(r, list) for r in test_res) - test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, var("z"), tuple)) + test_res = run(4, q_lv, map_anyo(math_reduceo, q_lv, z, null_type=tuple)) assert all(isinstance(r, tuple) for r in test_res) - x, y, z = var(), var(), var() + x, y = var(), var() def test_bin(a, b): return conde([eq(a, 1), eq(b, 2)]) @@ -231,6 +246,21 @@ def test_bin(a, b): assert s is not False assert all(isvar(i) for i in reify((x, y, z), s)) + # With ground ordering, this function should only be called once + n = 0 + + def eq_count(x, y): + nonlocal n + n += 1 + return eq(x, y) + + run(0, q_lv, map_anyo(eq_count, [x, y, 3], [y, x, 2])) + assert n == 1 + + n = 0 + run(0, q_lv, map_anyo(eq_count, [x, y, 3], [y, x, 2], use_ground_ordering=False)) + assert n == 3 + @pytest.mark.parametrize( "test_input, test_output", @@ -323,7 +353,7 @@ def test_map_anyo_reverse(): def test_walko_misc(): - q_lv = var(prefix="q") + q_lv = var() expr = etuple(add, etuple(mul, 2, 1), etuple(add, 1, 1)) res = run(0, q_lv, walko(eq, expr, expr)) @@ -353,6 +383,18 @@ def one_to_threeo(x, y): etuple(), ) + def map_rel(walk_g, x, y, **kwargs): + rator_in, rator_out = var(), var() + rands_in, rands_out = var(), var() + + return lall( + applyo(rator_in, rands_in, x), + applyo(rator_out, rands_out, y), + eq(rator_in, add), + eq(rator_out, add), + map_anyo(walk_g, x, y, **kwargs), + ) + res = run( 1, q_lv, @@ -366,7 +408,7 @@ def one_to_threeo(x, y): ), q_lv, # Only descend into `add` terms - rator_goal=lambda x, y: lall(eq(x, add), eq(y, add)), + map_rel=map_rel, ), ) @@ -376,6 +418,53 @@ def one_to_threeo(x, y): ), ) + # Now, we check that the `use_ground_order_seqs` option prevents infinite + # loops in `walko`. + + # This would go on forever between the two variable terms, even though + # the `(2, 3)` pair would cause it fail (if it were ever reached) + # run(1, True, walko(eq, [var(), 2], [var(), 3])) + run(1, True, walko(eq, [var(), 2], [var(), 3])) + + # Only walk rators for terms with the same car + def same_op(x, y, a, b): + rator_in = var() + rands_in, rands_out = var(), var() + + return conde( + [ + applyo(rator_in, rands_in, x), + applyo(rator_in, rands_out, y), + eq(a, rands_in), + eq(b, rands_out), + ], + [eq(a, None), eq(b, None)], + ) + + def one_to_threeo(x, y): + return conde([eq(x, 1), eq(y, 3)], [eq(x, None), eq(y, None)]) + + # This won't work without the pre-processing function, because + # `one_to_threeo` will fail when given `add, add` arguments + assert ( + run(0, True, walko(one_to_threeo, [add, 1, 1], [add, 3, 3], map_rel=mapo)) == () + ) + + assert ( + run( + 1, + True, + walko( + one_to_threeo, + [add, 1, 1], + [add, 3, 3], + pre_process_fn=same_op, + map_rel=mapo, + ), + ) + == (True,) + ) + @pytest.mark.parametrize( "test_input, test_output", @@ -413,11 +502,11 @@ def test_walko(test_input, test_output): """Test `walko` with fully ground terms (i.e. no logic variables).""" q_lv = var() - term_walko_fp = partial(reduceo, partial(term_walko, single_math_reduceo)) + walko_term_fp = partial(reduceo, partial(walko_term, single_math_reduceo)) test_res = run( len(test_output), q_lv, - term_walko_fp(test_input, q_lv), + walko_term_fp(test_input, q_lv), results_filter=toolz.unique, ) @@ -438,7 +527,7 @@ def test_walko_reverse(): """Test `walko` in "reverse" (i.e. specify the reduced form and generate the un-reduced form).""" # noqa: E501 q_lv = var("q") - test_res = run(2, q_lv, term_walko(math_reduceo, q_lv, 5)) + test_res = run(2, q_lv, walko_term(math_reduceo, q_lv, 5)) assert test_res == ( etuple(log, etuple(exp, 5)), etuple(log, etuple(exp, etuple(log, etuple(exp, 5)))), @@ -446,7 +535,7 @@ def test_walko_reverse(): assert all(e.eval_obj == 5.0 for e in test_res) # Make sure we get some variety in the results - test_res = run(2, q_lv, term_walko(math_reduceo, q_lv, etuple(mul, 2, 5))) + test_res = run(2, q_lv, walko_term(math_reduceo, q_lv, etuple(mul, 2, 5))) assert test_res == ( # Expansion of the term's root etuple(add, 5, 5), @@ -460,7 +549,7 @@ def test_walko_reverse(): assert all(e.eval_obj == 10.0 for e in test_res) r_lv = var("r") - test_res = run(4, [q_lv, r_lv], term_walko(math_reduceo, q_lv, r_lv)) + test_res = run(4, [q_lv, r_lv], walko_term(math_reduceo, q_lv, r_lv)) expect_res = ( [etuple(add, 1, 1), etuple(mul, 2, 1)], [etuple(log, etuple(exp, etuple(add, 1, 1))), etuple(mul, 2, 1)], diff --git a/tests/test_term.py b/tests/test_term.py index 740b645..d28df26 100644 --- a/tests/test_term.py +++ b/tests/test_term.py @@ -1,68 +1,11 @@ from cons import cons +from cons.core import ConsType from etuples import etuple from unification import reify, unify, var -from kanren.core import run -from kanren.term import applyo, arguments, operator, term, unifiable_with_term - - -@unifiable_with_term -class Node(object): - def __init__(self, op, args): - self.op = op - self.args = args - - def __eq__(self, other): - return ( - type(self) == type(other) - and self.op == other.op - and self.args == other.args - ) - - def __hash__(self): - return hash((type(self), self.op, self.args)) - - def __str__(self): - return "%s(%s)" % (self.op.name, ", ".join(map(str, self.args))) - - __repr__ = __str__ - - -class Operator(object): - def __init__(self, name): - self.name = name - - -Add = Operator("add") -Mul = Operator("mul") - - -def add(*args): - return Node(Add, args) - - -def mul(*args): - return Node(Mul, args) - - -class Op(object): - def __init__(self, name): - self.name = name - - -@arguments.register(Node) -def arguments_Node(t): - return t.args - - -@operator.register(Node) -def operator_Node(t): - return t.op - - -@term.register(Operator, (list, tuple)) -def term_Op(op, args): - return Node(op, args) +from kanren.core import run, shallow_ground_order_key +from kanren.term import TermType, applyo, arguments, operator, term +from tests.utils import Add, Node, Operator def test_applyo(): @@ -103,21 +46,59 @@ def test_applyo(): def test_applyo_object(): x = var() - assert run(0, x, applyo(Add, (1, 2, 3), x)) == (add(1, 2, 3),) - assert run(0, x, applyo(x, (1, 2, 3), add(1, 2, 3))) == (Add,) - assert run(0, x, applyo(Add, x, add(1, 2, 3))) == ((1, 2, 3),) + assert run(0, x, applyo(Add, (1, 2, 3), x)) == (Add(1, 2, 3),) + assert run(0, x, applyo(x, (1, 2, 3), Add(1, 2, 3))) == (Add,) + assert run(0, x, applyo(Add, x, Add(1, 2, 3))) == ((1, 2, 3),) -def test_unifiable_with_term(): - add = Operator("add") - t = Node(add, (1, 2)) +def test_term_dispatch(): + + t = Node(Add, (1, 2)) assert arguments(t) == (1, 2) - assert operator(t) == add + assert operator(t) == Add assert term(operator(t), arguments(t)) == t + +def test_unifiable_with_term(): + from kanren.term import unifiable_with_term + + @unifiable_with_term + class NewNode(Node): + pass + + class NewOperator(Operator): + def __call__(self, *args): + return NewNode(self, args) + + NewAdd = NewOperator("newadd") + x = var() - s = unify(Node(add, (1, x)), Node(add, (1, 2)), {}) + s = unify(NewNode(NewAdd, (1, x)), NewNode(NewAdd, (1, 2)), {}) assert s == {x: 2} - assert reify(Node(add, (1, x)), s) == Node(add, (1, 2)) + assert reify(NewNode(NewAdd, (1, x)), s) == NewNode(NewAdd, (1, 2)) + + +def test_TermType(): + assert issubclass(type(Add(1, 2)), TermType) + assert isinstance(Add(1, 2), TermType) + # assert not issubclass(type([1, 2]), TermType) + # assert not isinstance([1, 2], TermType) + assert issubclass(type([1, 2]), TermType) + assert isinstance([1, 2], TermType) + assert not isinstance(ConsType, TermType) + assert not issubclass(type(ConsType), TermType) + + +def test_shallow_ground_order(): + + x, y, z = var(), var(), var() + + assert shallow_ground_order_key({}, x) > shallow_ground_order_key({}, Add(x, y)) + assert shallow_ground_order_key({}, cons(x, y)) > shallow_ground_order_key( + {}, Add(x, y) + ) + assert shallow_ground_order_key({}, Add(x, y)) == shallow_ground_order_key( + {}, Add(x, y, z) + ) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..4f04c41 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,88 @@ +import random + +from kanren.term import arguments, operator + + +class Node(object): + def __init__(self, op, args): + self.op = op + self.args = args + + def __eq__(self, other): + return ( + type(self) == type(other) + and self.op == other.op + and self.args == other.args + ) + + def __hash__(self): + return hash((type(self), self.op, self.args)) + + def __str__(self): + return "%s(%s)" % (self.op.name, ", ".join(map(str, self.args))) + + __repr__ = __str__ + + +class Operator(object): + def __init__(self, name): + self.name = name + + def __call__(self, *args): + return Node(self, args) + + def __eq__(self, other): + return type(self) == type(other) and self.name == other.name + + def __hash__(self): + return hash((type(self), self.name)) + + def __str__(self): + return self.name + + __repr__ = __str__ + + +Add = Operator("add") +Mul = Operator("mul") + + +@arguments.register(Node) +def arguments_Node(t): + return t.args + + +@operator.register(Node) +def operator_Node(t): + return t.op + + +# @term.register(Operator, Sequence) +# def term_Operator(op, args): +# return Node(op, args) + + +def generate_term(ops, args, i=10, gen_fn=None): + + if gen_fn is not None: + + gen_res = gen_fn(ops, args, i) + + if gen_res is not None: + return gen_res + + g_op = random.choice(ops) + + if i > 0: + num_sub_graphs = len(args) // 2 + else: + num_sub_graphs = 0 + + g_args = random.sample(args, len(args) - num_sub_graphs) + g_args += [ + generate_term(ops, args, i=i - 1, gen_fn=gen_fn) for s in range(num_sub_graphs) + ] + + random.shuffle(g_args) + + return [g_op] + list(g_args)