diff --git a/mk.lua b/mk.lua index f874240..9417c42 100644 --- a/mk.lua +++ b/mk.lua @@ -167,15 +167,22 @@ local function walk_all(v, s) end end +local function occurs_check(var, term) + if not is_table(term) then return false end + if #term == 0 then return false end + if is_var(term) then return var == term end + return occurs_check(var, car(term)) or occurs_check(var, cdr(term)) +end + local function unify(k, v, s) local k = walk(k, s) local v = walk(v, s) if k == v then return s elseif is_var(k) then - return cons(cons(k, v), s) + return (not occurs_check(k, v)) and cons(cons(k, v), s) elseif is_var(v) then - return cons(cons(v, k), s) + return (not occurs_check(v, k)) and cons(cons(v, k), s) elseif is_var(k) or is_var(v) then return cons(cons(k, v), s) elseif is_pair(v) and is_pair(k) then diff --git a/test.lua b/test.lua index f1cc459..44ba261 100644 --- a/test.lua +++ b/test.lua @@ -105,6 +105,10 @@ assert(equal(run(false, {a, b, c}, all(eq(a, 2), eq(a, b), not_eq(b, 3))), {{2, assert(equal(run(false, {a, b, c}, all(eq(c, 2), eq(a, b), not_eq(b, 3))), { { "_.0 not eq: 3", "_.0 not eq: 3", 2 } })) assert(equal(run(false, {a, b}, all(eq(a, b), not_eq(a, 3))), { { "_.0 not eq: 3", "_.0 not eq: 3" } })) +assert(equal(run(1, a, eq(a, list(a, a))), {})) +assert(equal(run(1, a, eq(a, list(a, b))), {})) +assert(equal(run(1, a, eq(a, list(b, a))), {})) + assert(equal(run(1, a, mergeo({1, {2, {3}}}, {4, {5, {6, {}}}}, a)), { list(1, 2, 3, 4, 5, 6) })) assert(equal(run(1, a, mergeo({1}, {2}, a)), { {1, {2}} }))