diff --git a/src/categorical_algebra/HomSearch.jl b/src/categorical_algebra/HomSearch.jl index d197fb24f..a232bb69f 100644 --- a/src/categorical_algebra/HomSearch.jl +++ b/src/categorical_algebra/HomSearch.jl @@ -10,7 +10,8 @@ export ACSetHomomorphismAlgorithm, BacktrackingSearch, HomomorphismQuery, homomorphism, homomorphisms, is_homomorphic, isomorphism, isomorphisms, is_isomorphic, @acset_transformation, @acset_transformations, - subobject_graph, partial_overlaps, maximum_common_subobject + subobject_graph, partial_overlaps, maximum_common_subobject, + debug_homomorphisms using ...Theories, ..CSets, ..FinSets, ..FreeDiagrams, ..Subobjects using ...Graphs.BasicGraphs @@ -20,7 +21,7 @@ using ACSets.DenseACSets: attrtype_type, delete_subobj! using Random using CompTime using MLStyle: @match -using DataStructures: BinaryHeap, DefaultDict +using DataStructures: BinaryHeap, DefaultDict, OrderedDict # Finding C-set transformations ############################### @@ -85,7 +86,7 @@ homomorphism(X::ACSet, Y::ACSet; alg=BacktrackingSearch(), kw...) = function homomorphism(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...) result = nothing backtracking_search(X, Y; kw...) do α - result = α; return true + result = get_hom(α); return true end result end @@ -101,11 +102,19 @@ homomorphisms(X::ACSet, Y::ACSet; alg=BacktrackingSearch(), kw...) = function homomorphisms(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...) results = [] backtracking_search(X, Y; kw...) do α - push!(results, map_components(deepcopy, α)); return false + push!(results, map_components(deepcopy, get_hom(α))); return false end results end +function debug_homomorphisms(X::ACSet, Y::ACSet; kw...) + results = [] + m = backtracking_search(X, Y; debug=true, kw...) do α + push!(results, map_components(deepcopy, get_hom(α))); return false + end + results => m.debug +end + """ Is the first attributed ``C``-set homomorphic to the second? This function generally reduces to [`homomorphism`](@ref) but certain algorithms @@ -152,6 +161,60 @@ is_isomorphic(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...) = # Backtracking search #-------------------- +"""Keep track of progress through backtracking homomorphism search.""" +mutable struct BacktrackingTree + node::Union{Nothing,Pair{Symbol,Int}} + success::Bool + asgn::NamedTuple + children::OrderedDict{Int,BacktrackingTree} + BacktrackingTree() = new(nothing, false, (;), OrderedDict{Int,BacktrackingTree}()) +end + +"""A backtracking tree plus a pointer to a node in the tree""" +struct BacktrackingTreePt + t::BacktrackingTree + curr::Vector{Int} + BacktrackingTreePt() = new(BacktrackingTree(),Int[]) +end + +function Base.push!(tc::BacktrackingTreePt, c::Symbol, x::Int, y::Int, asgn) + t = tc.t[tc.curr] + t.node = c => x + t.children[y] = BacktrackingTree() + t.children[y].asgn = deepcopy(asgn) + push!(tc.curr, y) + return true +end + +function Base.delete!(tc::BacktrackingTreePt, c::Symbol, x::Int, y::Int) + t = tc.t[tc.curr[1:end-1]] + t.node == (c=>x) || error("Bad remove $c#$x->$y") + pop!(tc.curr) +end + +function success(tc::BacktrackingTreePt) + tc.t[tc.curr].success = true +end + +function Base.show(io::IO, t::BacktrackingTree) + if !isnothing(t.node) + print(io,"{"); print(io, t.node[1]); print(io, t.node[2]); print(io,"}"); + end + print(io, "[") + for (k,v) in collect(t.children) + print(io, k); print(io, v); print(io, ",") + end + if !isempty(t.children) print(io,"\b") end + print(io,"]") +end + +function Base.getindex(t::BacktrackingTree, curr::Vector{Int}) + for c in curr + t = t.children[c] + end + t +end + """ Get assignment pairs from partially specified component of C-set morphism. """ partial_assignments(x::FinFunction; is_attr=false) = partial_assignments(collect(x)) @@ -177,10 +240,25 @@ struct BacktrackingState{ dom::Dom codom::Codom type_components::LooseFun + debug::Union{Nothing,BacktrackingTreePt} +end + +"""Extract an ACSetTransformation from BacktrackingState""" +function get_hom(state::BacktrackingState) + if any(!=(identity), state.type_components) + return LooseACSetTransformation( + state.assignment, state.type_components, state.dom, state.codom) + else + S = acset_schema(state.dom) + od = Dict{Symbol,Vector{Int}}(k=>(state.assignment[k]) for k in objects(S)) + ad = Dict(k=>last.(state.assignment[k]) for k in attrtypes(S)) + comps = merge(NamedTuple(od),NamedTuple(ad)) + return ACSetTransformation(comps, state.dom, state.codom) + end end function backtracking_search(f, X::ACSet, Y::ACSet; - monic=false, iso=false, random=false, + monic=false, iso=false, random=false, debug=false, type_components=(;), initial=(;), error_failures=false) S, Sy = acset_schema.([X,Y]) S == Sy || error("Schemas must match for morphism search") @@ -235,9 +313,11 @@ function backtracking_search(f, X::ACSet, Y::ACSet; inv_assignment = NamedTuple{ObAttr}( (c in monic ? zeros(Int, nparts(Y, c)) : nothing) for c in ObAttr) loosefuns = NamedTuple{Attr}( - isnothing(type_components) ? identity : get(type_components, c, identity) for c in Attr) - state = BacktrackingState(assignment, assignment_depth, - inv_assignment, X, Y, loosefuns) + isnothing(type_components) ? identity : get(type_components, c, identity) + for c in Attr) + + state = BacktrackingState(assignment, assignment_depth, inv_assignment, X, Y, + loosefuns, debug ? BacktrackingTreePt() : nothing) # Make any initial assignments, failing immediately if inconsistent. for (c, c_assignments) in pairs(initial) @@ -252,39 +332,32 @@ function backtracking_search(f, X::ACSet, Y::ACSet; end end # Start the main recursion for backtracking search. - backtracking_search(f, state, 1; random=random) + backtracking_search(f, state, 1; random=random, toplevel=true) end function backtracking_search(f, state::BacktrackingState, depth::Int; - random=false) + random=false, toplevel=false) # Choose the next unassigned element. mrv, mrv_elem = find_mrv_elem(state, depth) if isnothing(mrv_elem) - # No unassigned elements remain, so we have a complete assignment. - if any(!=(identity), state.type_components) - return f(LooseACSetTransformation( - state.assignment, state.type_components, state.dom, state.codom)) - else - S = acset_schema(state.dom) - od = Dict{Symbol,Vector{Int}}(k=>(state.assignment[k]) for k in objects(S)) - ad = Dict(k=>last.(state.assignment[k]) for k in attrtypes(S)) - comps = merge(NamedTuple(od),NamedTuple(ad)) - return f(ACSetTransformation(comps, state.dom, state.codom)) - end + isnothing(state.debug) || success(state.debug) + return f(state) elseif mrv == 0 # An element has no allowable assignment, so we must backtrack. return false end - c, x = mrv_elem + c, x, ys = mrv_elem # Attempt all assignments of the chosen element. Y = state.codom - for y in (random ? shuffle : identity)(parts(Y, c)) + for y in (random ? shuffle : identity)(ys) (assign_elem!(state, depth, c, x, y) + && (isnothing(state.debug) ? true : push!(state.debug, c, x, y, state.assignment)) && backtracking_search(f, state, depth + 1)) && return true unassign_elem!(state, depth, c, x) + isnothing(state.debug) || delete!(state.debug, c, x, state.assignment[c][x]) end - return false + return toplevel ? state : false # return state to recover debug tree end """ Find an unassigned element having the minimum remaining values (MRV). @@ -295,9 +368,12 @@ function find_mrv_elem(state::BacktrackingState, depth) Y = state.codom for c in ob(S), (x, y) in enumerate(state.assignment[c]) y == 0 || continue - n = count(can_assign_elem(state, depth, c, x, y) for y in parts(Y, c)) + ys = filter(parts(Y,c)) do y + can_assign_elem(state, depth, c, x, y) + end + n = length(ys) if n < mrv - mrv, mrv_elem = n, (c, x) + mrv, mrv_elem = n, (c, x, ys) end end (mrv, mrv_elem) diff --git a/src/graphics/GraphvizCategories.jl b/src/graphics/GraphvizCategories.jl index 330a72fa7..d9103cd59 100644 --- a/src/graphics/GraphvizCategories.jl +++ b/src/graphics/GraphvizCategories.jl @@ -6,6 +6,7 @@ export to_graphviz, to_graphviz_property_graph using ...GATs, ...Theories, ...CategoricalAlgebra, ...Graphs, ..GraphvizGraphs import ..Graphviz import ..GraphvizGraphs: to_graphviz, to_graphviz_property_graph +using ...CategoricalAlgebra.HomSearch: BacktrackingTree, BacktrackingTreePt # Presentations ############### @@ -143,4 +144,37 @@ function to_graphviz(f::FinFunction{Int,Int}; kw...) to_graphviz(g; kw...) end +# Search trees +############### +to_graphviz(t::BacktrackingTreePt) = to_graphviz(t.t) + +function to_graphviz(t::BacktrackingTree) + pg = PropertyGraph{Any}(; + prog = "dot", + graph = Dict(), + node = merge!(Dict(:shape => "box", :width => ".1", :height => ".1", + :margin => "0.025", :style=>"filled")), + edge = Dict()) + kwargs(tr::BacktrackingTree) = ( + fillcolor=tr.success ? "green" : "red", + tooltip=isempty(tr.asgn) ? "" : string(tr.asgn), + label = isnothing(tr.node) ? "" : join(string.([tr.node...]))) + add_vertex!(pg; kwargs(t)...) + queue = [Int[]] + paths = Dict([Int[]=>1]) # path to vertex + while !isempty(queue) + curr = popfirst!(queue) + subt = t[curr] + # We ought print the index too, but graphviz renders edges in right order + for (_,e) in enumerate(keys(subt.children)) + new_pth = [curr...,e] + v = add_vertex!(pg; kwargs(t[new_pth])...) + paths[new_pth] = v + add_edge!(pg, paths[curr], v; label=string("$e")) + push!(queue, new_pth) + end + end + to_graphviz(pg) +end + end diff --git a/test/categorical_algebra/HomSearch.jl b/test/categorical_algebra/HomSearch.jl index 865409615..739389def 100644 --- a/test/categorical_algebra/HomSearch.jl +++ b/test/categorical_algebra/HomSearch.jl @@ -147,6 +147,42 @@ end @test length(@acset_transformations x x begin V = Dict(1=>1) end monic = [:E]) == 2 @test_throws ErrorException @acset_transformation g h begin V = [4,3,2,1]; E = [1,2,3,4] end +# Debug graph +#------------ +@present SchTri <: SchGraph begin + T::Ob + (t1,t2,t3)::Hom(T,E) + t1 ⋅ src == t2 ⋅ src + t1 ⋅ tgt == t3 ⋅ tgt + t2 ⋅ src == t3 ⋅ src +end + +@acset_type Tri(SchTri) + +""" e₃ + 2 ← 4 +e₁↑ ↖ ↓ e₄ + 1 → 3 + e₂ +""" +quad = @acset Tri begin V=4; E=5; T=2; + src=[1,1,4,4,3]; tgt=[2,3,2,3,2]; + t1=[1,3]; t2=[2,4]; t3=[5,5] +end + +term = apex(terminal(Tri)) + +tri5 = @acset Tri begin + V=2; E=3; T=5; src=[1,1,2]; tgt=[2,2,2]; t1=1; t2=2; t3=3 +end + +tri = @acset Tri begin + V=3; E=3; T=1; src=[1,1,2]; tgt=[3,2,3]; t1=1; t2=2; t3=3 +end + +hs, t = debug_homomorphisms(tri, quad ⊕ tri5; monic=false) +@test length(hs) == length(homomorphisms(tri, quad ⊕ tri5)) +# to_graphviz(t) # Enumeration of subobjects ###########################