Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 205 additions & 96 deletions kanren/assoccomm.py

Large diffs are not rendered by default.

113 changes: 89 additions & 24 deletions kanren/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from collections.abc import Generator, Sequence
from collections.abc import Generator, Mapping, Sequence
from functools import partial, reduce
from itertools import tee
from operator import length_hint
from statistics import mean

from cons.core import ConsPair
from cons.core import ConsPair, car, cdr
from multipledispatch import dispatch
from toolz import interleave, take
from unification import isvar, reify, unify
from unification.core import isground


def fail(s):
Expand Down Expand Up @@ -113,18 +114,25 @@ def conde(*goals):
lany = ldisj


def ground_order_key(S, x):
@dispatch(Mapping, object)
def shallow_ground_order_key(S, x):
if isvar(x):
return 2
elif isground(x, S):
return -1
elif issubclass(type(x), ConsPair):
return 1
else:
return 0


def ground_order(in_args, out_args):
return 10
elif isinstance(x, ConsPair):
val = 0
val += 1 if isvar(car(x)) else 0
cdr_x = cdr(x)
if issubclass(type(x), ConsPair):
val += 2 if isvar(cdr_x) else 0
elif len(cdr_x) == 1:
val += 1 if isvar(cdr_x[0]) else 0
elif len(cdr_x) > 1:
val += mean(1.0 if isvar(i) else 0.0 for i in cdr_x)
return val
return 0


def ground_order(in_args, out_args, key_fn=shallow_ground_order_key):
"""Construct a non-relational goal that orders a list of terms based on groundedness (grounded precede ungrounded).""" # noqa: E501

def ground_order_goal(S):
Expand All @@ -134,7 +142,7 @@ def ground_order_goal(S):

S_new = unify(
list(out_args_rf) if isinstance(out_args_rf, Sequence) else out_args_rf,
sorted(in_args_rf, key=partial(ground_order_key, S)),
sorted(in_args_rf, key=partial(key_fn, S)),
S,
)

Expand All @@ -144,6 +152,48 @@ def ground_order_goal(S):
return ground_order_goal


def ground_order_seqs(in_seqs, out_seqs, key_fn=shallow_ground_order_key):
"""Construct a non-relational goal that orders lists of sequences based on the groundedness of their corresponding terms. # noqa: E501

>>> from unification import var
>>> x, y = var('x'), var('y')
>>> a, b = var('a'), var('b')
>>> run(0, (x, y), ground_order_seqs([(a, b), (b, 2)], [x, y]))
(((~b, ~a), (2, ~b)),)
"""

def ground_order_seqs_goal(S):
nonlocal in_seqs, out_seqs, key_fn

in_seqs_rf, out_seqs_rf = reify((in_seqs, out_seqs), S)

if (
not any(isinstance(s, str) for s in in_seqs_rf)
and reduce(
lambda x, y: x == y and y, (length_hint(s, -1) for s in in_seqs_rf)
)
> 0
):

in_seqs_ord = zip(*sorted(zip(*in_seqs_rf), key=partial(key_fn, S)))
S_new = unify(
list(out_seqs_rf),
[type(j)(i) for i, j in zip(in_seqs_ord, in_seqs_rf)],
S,
)

if S_new is not False:
yield S_new
else:

S_new = unify(out_seqs_rf, in_seqs_rf, S)

if S_new is not False:
yield S_new

return ground_order_seqs_goal


def ifa(g1, g2):
"""Create a goal operator that returns the first stream unless it fails."""

Expand Down Expand Up @@ -204,22 +254,37 @@ def run(n, x, *goals, results_filter=None):
return tuple(take(n, results))


def dbgo(*args, msg=None): # pragma: no cover
"""Construct a goal that sets a debug trace and prints reified arguments."""
def dbgo(*args, msg=None, pdb=False, print_asap=True, trace=True): # pragma: no cover
"""Construct a goal that prints reified arguments and, optionally, sets a debug trace.""" # noqa: E501
from pprint import pprint

from unification import var

trace_var = var("__dbgo_trace")

def dbgo_goal(S):
nonlocal args
args = reify(args, S)
nonlocal args, msg, pdb, print_asap, trace_var, trace

args_rf, trace_rf = reify((args, trace_var), S)

if trace:
S = S.copy()
if isvar(trace_rf):
S[trace_var] = [(msg, tuple(str(a) for a in args_rf))]
else:
trace_rf.append((msg, tuple(str(a) for a in args_rf)))
S[trace_var] = trace_rf

if msg is not None:
print(msg)
if print_asap:
if msg is not None:
print(msg)
pprint(args_rf)

pprint(args)
if pdb:
import pdb

import pdb
pdb.set_trace()

pdb.set_trace()
yield S

return dbgo_goal
Loading