Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 102 additions & 26 deletions src/categorical_algebra/HomSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ export ACSetHomomorphismAlgorithm, BacktrackingSearch, HomomorphismQuery,
homomorphism, homomorphisms, is_homomorphic,
isomorphism, isomorphisms, is_isomorphic,
@acset_transformation, @acset_transformations,
subobject_graph, partial_overlaps, maximum_common_subobject
subobject_graph, partial_overlaps, maximum_common_subobject,
debug_homomorphisms

using ...Theories, ..CSets, ..FinSets, ..FreeDiagrams, ..Subobjects
using ...Graphs.BasicGraphs
Expand All @@ -20,7 +21,7 @@ using ACSets.DenseACSets: attrtype_type, delete_subobj!
using Random
using CompTime
using MLStyle: @match
using DataStructures: BinaryHeap, DefaultDict
using DataStructures: BinaryHeap, DefaultDict, OrderedDict

# Finding C-set transformations
###############################
Expand Down Expand Up @@ -85,7 +86,7 @@ homomorphism(X::ACSet, Y::ACSet; alg=BacktrackingSearch(), kw...) =
function homomorphism(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...)
result = nothing
backtracking_search(X, Y; kw...) do α
result = α; return true
result = get_hom(α); return true
end
result
end
Expand All @@ -101,11 +102,19 @@ homomorphisms(X::ACSet, Y::ACSet; alg=BacktrackingSearch(), kw...) =
function homomorphisms(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...)
results = []
backtracking_search(X, Y; kw...) do α
push!(results, map_components(deepcopy, α)); return false
push!(results, map_components(deepcopy, get_hom(α))); return false
end
results
end

function debug_homomorphisms(X::ACSet, Y::ACSet; kw...)
results = []
m = backtracking_search(X, Y; debug=true, kw...) do α
push!(results, map_components(deepcopy, get_hom(α))); return false
end
results => m.debug
end

""" Is the first attributed ``C``-set homomorphic to the second?

This function generally reduces to [`homomorphism`](@ref) but certain algorithms
Expand Down Expand Up @@ -152,6 +161,60 @@ is_isomorphic(X::ACSet, Y::ACSet, alg::BacktrackingSearch; kw...) =
# Backtracking search
#--------------------

"""Keep track of progress through backtracking homomorphism search."""
mutable struct BacktrackingTree
node::Union{Nothing,Pair{Symbol,Int}}
success::Bool
asgn::NamedTuple
children::OrderedDict{Int,BacktrackingTree}
BacktrackingTree() = new(nothing, false, (;), OrderedDict{Int,BacktrackingTree}())
end

"""A backtracking tree plus a pointer to a node in the tree"""
struct BacktrackingTreePt
t::BacktrackingTree
curr::Vector{Int}
BacktrackingTreePt() = new(BacktrackingTree(),Int[])
end

function Base.push!(tc::BacktrackingTreePt, c::Symbol, x::Int, y::Int, asgn)
t = tc.t[tc.curr]
t.node = c => x
t.children[y] = BacktrackingTree()
t.children[y].asgn = deepcopy(asgn)
push!(tc.curr, y)
return true
end

function Base.delete!(tc::BacktrackingTreePt, c::Symbol, x::Int, y::Int)
t = tc.t[tc.curr[1:end-1]]
t.node == (c=>x) || error("Bad remove $c#$x->$y")
pop!(tc.curr)
end

function success(tc::BacktrackingTreePt)
tc.t[tc.curr].success = true
end

function Base.show(io::IO, t::BacktrackingTree)
if !isnothing(t.node)
print(io,"{"); print(io, t.node[1]); print(io, t.node[2]); print(io,"}");
end
print(io, "[")
for (k,v) in collect(t.children)
print(io, k); print(io, v); print(io, ",")
end
if !isempty(t.children) print(io,"\b") end
print(io,"]")
end

function Base.getindex(t::BacktrackingTree, curr::Vector{Int})
for c in curr
t = t.children[c]
end
t
end

""" Get assignment pairs from partially specified component of C-set morphism.
"""
partial_assignments(x::FinFunction; is_attr=false) = partial_assignments(collect(x))
Expand All @@ -177,10 +240,25 @@ struct BacktrackingState{
dom::Dom
codom::Codom
type_components::LooseFun
debug::Union{Nothing,BacktrackingTreePt}
end

"""Extract an ACSetTransformation from BacktrackingState"""
function get_hom(state::BacktrackingState)
if any(!=(identity), state.type_components)
return LooseACSetTransformation(
state.assignment, state.type_components, state.dom, state.codom)
else
S = acset_schema(state.dom)
od = Dict{Symbol,Vector{Int}}(k=>(state.assignment[k]) for k in objects(S))
ad = Dict(k=>last.(state.assignment[k]) for k in attrtypes(S))
comps = merge(NamedTuple(od),NamedTuple(ad))
return ACSetTransformation(comps, state.dom, state.codom)
end
end

function backtracking_search(f, X::ACSet, Y::ACSet;
monic=false, iso=false, random=false,
monic=false, iso=false, random=false, debug=false,
type_components=(;), initial=(;), error_failures=false)
S, Sy = acset_schema.([X,Y])
S == Sy || error("Schemas must match for morphism search")
Expand Down Expand Up @@ -235,9 +313,11 @@ function backtracking_search(f, X::ACSet, Y::ACSet;
inv_assignment = NamedTuple{ObAttr}(
(c in monic ? zeros(Int, nparts(Y, c)) : nothing) for c in ObAttr)
loosefuns = NamedTuple{Attr}(
isnothing(type_components) ? identity : get(type_components, c, identity) for c in Attr)
state = BacktrackingState(assignment, assignment_depth,
inv_assignment, X, Y, loosefuns)
isnothing(type_components) ? identity : get(type_components, c, identity)
for c in Attr)

state = BacktrackingState(assignment, assignment_depth, inv_assignment, X, Y,
loosefuns, debug ? BacktrackingTreePt() : nothing)

# Make any initial assignments, failing immediately if inconsistent.
for (c, c_assignments) in pairs(initial)
Expand All @@ -252,39 +332,32 @@ function backtracking_search(f, X::ACSet, Y::ACSet;
end
end
# Start the main recursion for backtracking search.
backtracking_search(f, state, 1; random=random)
backtracking_search(f, state, 1; random=random, toplevel=true)
end

function backtracking_search(f, state::BacktrackingState, depth::Int;
random=false)
random=false, toplevel=false)
# Choose the next unassigned element.
mrv, mrv_elem = find_mrv_elem(state, depth)
if isnothing(mrv_elem)
# No unassigned elements remain, so we have a complete assignment.
if any(!=(identity), state.type_components)
return f(LooseACSetTransformation(
state.assignment, state.type_components, state.dom, state.codom))
else
S = acset_schema(state.dom)
od = Dict{Symbol,Vector{Int}}(k=>(state.assignment[k]) for k in objects(S))
ad = Dict(k=>last.(state.assignment[k]) for k in attrtypes(S))
comps = merge(NamedTuple(od),NamedTuple(ad))
return f(ACSetTransformation(comps, state.dom, state.codom))
end
isnothing(state.debug) || success(state.debug)
return f(state)
elseif mrv == 0
# An element has no allowable assignment, so we must backtrack.
return false
end
c, x = mrv_elem
c, x, ys = mrv_elem

# Attempt all assignments of the chosen element.
Y = state.codom
for y in (random ? shuffle : identity)(parts(Y, c))
for y in (random ? shuffle : identity)(ys)
(assign_elem!(state, depth, c, x, y)
&& (isnothing(state.debug) ? true : push!(state.debug, c, x, y, state.assignment))
&& backtracking_search(f, state, depth + 1)) && return true
unassign_elem!(state, depth, c, x)
isnothing(state.debug) || delete!(state.debug, c, x, state.assignment[c][x])
end
return false
return toplevel ? state : false # return state to recover debug tree
end

""" Find an unassigned element having the minimum remaining values (MRV).
Expand All @@ -295,9 +368,12 @@ function find_mrv_elem(state::BacktrackingState, depth)
Y = state.codom
for c in ob(S), (x, y) in enumerate(state.assignment[c])
y == 0 || continue
n = count(can_assign_elem(state, depth, c, x, y) for y in parts(Y, c))
ys = filter(parts(Y,c)) do y
can_assign_elem(state, depth, c, x, y)
end
n = length(ys)
if n < mrv
mrv, mrv_elem = n, (c, x)
mrv, mrv_elem = n, (c, x, ys)
end
end
(mrv, mrv_elem)
Expand Down
34 changes: 34 additions & 0 deletions src/graphics/GraphvizCategories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export to_graphviz, to_graphviz_property_graph
using ...GATs, ...Theories, ...CategoricalAlgebra, ...Graphs, ..GraphvizGraphs
import ..Graphviz
import ..GraphvizGraphs: to_graphviz, to_graphviz_property_graph
using ...CategoricalAlgebra.HomSearch: BacktrackingTree, BacktrackingTreePt

# Presentations
###############
Expand Down Expand Up @@ -143,4 +144,37 @@ function to_graphviz(f::FinFunction{Int,Int}; kw...)
to_graphviz(g; kw...)
end

# Search trees
###############
to_graphviz(t::BacktrackingTreePt) = to_graphviz(t.t)

function to_graphviz(t::BacktrackingTree)
pg = PropertyGraph{Any}(;
prog = "dot",
graph = Dict(),
node = merge!(Dict(:shape => "box", :width => ".1", :height => ".1",
:margin => "0.025", :style=>"filled")),
edge = Dict())
kwargs(tr::BacktrackingTree) = (
fillcolor=tr.success ? "green" : "red",
tooltip=isempty(tr.asgn) ? "" : string(tr.asgn),
label = isnothing(tr.node) ? "" : join(string.([tr.node...])))
add_vertex!(pg; kwargs(t)...)
queue = [Int[]]
paths = Dict([Int[]=>1]) # path to vertex
while !isempty(queue)
curr = popfirst!(queue)
subt = t[curr]
# We ought print the index too, but graphviz renders edges in right order
for (_,e) in enumerate(keys(subt.children))
new_pth = [curr...,e]
v = add_vertex!(pg; kwargs(t[new_pth])...)
paths[new_pth] = v
add_edge!(pg, paths[curr], v; label=string("$e"))
push!(queue, new_pth)
end
end
to_graphviz(pg)
end

end
36 changes: 36 additions & 0 deletions test/categorical_algebra/HomSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,42 @@ end
@test length(@acset_transformations x x begin V = Dict(1=>1) end monic = [:E]) == 2
@test_throws ErrorException @acset_transformation g h begin V = [4,3,2,1]; E = [1,2,3,4] end

# Debug graph
#------------
@present SchTri <: SchGraph begin
T::Ob
(t1,t2,t3)::Hom(T,E)
t1 ⋅ src == t2 ⋅ src
t1 ⋅ tgt == t3 ⋅ tgt
t2 ⋅ src == t3 ⋅ src
end

@acset_type Tri(SchTri)

""" e₃
2 ← 4
e₁↑ ↖ ↓ e₄
1 → 3
e₂
"""
quad = @acset Tri begin V=4; E=5; T=2;
src=[1,1,4,4,3]; tgt=[2,3,2,3,2];
t1=[1,3]; t2=[2,4]; t3=[5,5]
end

term = apex(terminal(Tri))

tri5 = @acset Tri begin
V=2; E=3; T=5; src=[1,1,2]; tgt=[2,2,2]; t1=1; t2=2; t3=3
end

tri = @acset Tri begin
V=3; E=3; T=1; src=[1,1,2]; tgt=[3,2,3]; t1=1; t2=2; t3=3
end

hs, t = debug_homomorphisms(tri, quad ⊕ tri5; monic=false)
@test length(hs) == length(homomorphisms(tri, quad ⊕ tri5))
# to_graphviz(t)

# Enumeration of subobjects
###########################
Expand Down