Skip to content

Commit a3f182b

Browse files
authored
simplify weighted (#17)
* simplify weighted * automatically general rotated and reflected gadgets * fix simplifier gadgets
1 parent aa03be1 commit a3f182b

File tree

7 files changed

+131
-135
lines changed

7 files changed

+131
-135
lines changed

src/UnitDiskMapping.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ export UNode, contract_graph, compress_graph
1212

1313
include("utils.jl")
1414
include("gadgets.jl")
15-
include("simplifiers.jl")
1615
include("mapping.jl")
1716
include("weighted.jl")
17+
include("simplifiers.jl")
1818
include("extracting_results.jl")
1919
include("pathdecomposition/pathdecomposition.jl")
2020
#include("shrinking/compressUDG.jl")

src/gadgets.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -386,16 +386,17 @@ for T in [:RotatedGadget, :ReflectedGadget]
386386
@eval mis_overhead(p::$T) = mis_overhead(p.gadget)
387387
end
388388

389-
function _apply_transform(r::RotatedGadget, node::Node, center)
390-
loc = getxy(node)
389+
for T in [:RotatedGadget, :ReflectedGadget]
390+
@eval _apply_transform(r::$T, node::Node, center) = chxy(node, _apply_transform(r, getxy(node), center))
391+
end
392+
function _apply_transform(r::RotatedGadget, loc::Tuple{Int,Int}, center)
391393
for _=1:r.n
392394
loc = rotate90(loc, center)
393395
end
394-
return chxy(node, loc)
396+
return loc
395397
end
396398

397-
function _apply_transform(r::ReflectedGadget, node::Node, center)
398-
loc = getxy(node)
399+
function _apply_transform(r::ReflectedGadget, loc::Tuple{Int,Int}, center)
399400
loc = if r.mirror == "x"
400401
reflectx(loc, center)
401402
elseif r.mirror == "y"
@@ -407,7 +408,7 @@ function _apply_transform(r::ReflectedGadget, node::Node, center)
407408
else
408409
throw(ArgumentError("reflection direction $(r.direction) is not defined!"))
409410
end
410-
chxy(node, loc)
411+
return loc
411412
end
412413

413414
export vertex_overhead
@@ -429,3 +430,17 @@ function _boundary_config(pins, config)
429430
end
430431
return res
431432
end
433+
434+
export rotated_and_reflected
435+
function rotated_and_reflected(p::Pattern)
436+
patterns = Pattern[p]
437+
source_matrices = [source_matrix(p)]
438+
for pi in [[RotatedGadget(p, i) for i=1:3]..., [ReflectedGadget(p, axis) for axis in ["x", "y", "diag", "offdiag"]]...]
439+
m = source_matrix(pi)
440+
if m source_matrices
441+
push!(patterns, pi)
442+
push!(source_matrices, m)
443+
end
444+
end
445+
return patterns
446+
end

src/mapping.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,10 @@ It can be a vector or one of the following inputs
373373
374374
Returns a `MappingResult` instance.
375375
"""
376-
map_graph(g::SimpleGraph; vertex_order=Greedy(), ruleset=[RotatedGadget(DanglingLeg(), n) for n=0:3]) = map_graph(UnWeighted(), g; ruleset=ruleset, vertex_order=vertex_order)
377-
function map_graph(mode, g::SimpleGraph; vertex_order=Greedy(), ruleset=[RotatedGadget(DanglingLeg(), n) for n=0:3])
376+
function map_graph(g::SimpleGraph; vertex_order=Greedy(), ruleset=default_simplifier_ruleset(UnWeighted()))
377+
map_graph(UnWeighted(), g; ruleset=ruleset, vertex_order=vertex_order)
378+
end
379+
function map_graph(mode, g::SimpleGraph; vertex_order=Greedy(), ruleset=default_simplifier_ruleset(mode))
378380
ug = embed_graph(mode, g; vertex_order=vertex_order)
379381
mis_overhead0 = mis_overhead_copylines(ug)
380382
ug, tape = apply_crossing_gadgets!(mode, ug)
@@ -385,3 +387,5 @@ function map_graph(mode, g::SimpleGraph; vertex_order=Greedy(), ruleset=[Rotated
385387
end
386388

387389
map_configs_back(r::MappingResult{<:Cell}, configs::AbstractVector) = unapply_gadgets!(copy(r.grid_graph), r.mapping_history, copy.(configs))[2]
390+
default_simplifier_ruleset(::UnWeighted) = vcat([rotated_and_reflected(rule) for rule in simplifier_ruleset]...)
391+
default_simplifier_ruleset(::Weighted) = weighted.(default_simplifier_ruleset(UnWeighted()))

src/simplifiers.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,41 +32,51 @@ struct GridGraph{NT<:Node}
3232
end
3333
vertices_on_boundary(gg::GridGraph) = vertices_on_boundary(gg.nodes, gg.size...)
3434

35-
function gridgraphfromstring(str::String)
36-
item_array = Vector{Bool}[]
35+
function gridgraphfromstring(mode::Union{Weighted, UnWeighted}, str::String)
36+
item_array = Vector{Int}[]
3737
for line in split(str, "\n")
38-
list = [item ("o", "") ? true : (@assert item (".", ""); false) for item in split(line, " ") if !isempty(item)]
38+
items = [item for item in split(line, " ") if !isempty(item)]
39+
list = if mode isa Weighted # TODO: the weighted version need to be tested! Consider removing it!
40+
@assert all(item->item (".", "", "@", "", "o", ""), items)
41+
[item ("@", "") ? 2 : (item ("o", "") ? 1 : 0) for item in items]
42+
else
43+
@assert all(item->item (".", "", "@", ""), items)
44+
[item ("@", "") ? 1 : 0 for item in items]
45+
end
3946
if !isempty(list)
4047
push!(item_array, list)
4148
end
4249
end
4350
@assert all(==(length(item_array[1])), length.(item_array))
4451
mat = hcat(item_array...)'
45-
locs = [SimpleNode(ci.I) for ci in findall(mat)]
52+
locs = [_to_node(mode, ci.I, mat[ci]) for ci in findall(!iszero, mat)]
4653
return GridGraph(size(mat), locs)
4754
end
55+
_to_node(::UnWeighted, loc::Tuple{Int,Int}, w::Int) = SimpleNode(loc...)
56+
_to_node(::Weighted, loc::Tuple{Int,Int}, w::Int) = WeightedNode(loc..., w)
4857

49-
const simplifier_ruleset = SimplifyPattern[]
50-
51-
macro gg(expr)
58+
function gg_func(mode, expr)
5259
@assert expr.head == :(=)
5360
name = expr.args[1]
5461
pair = expr.args[2]
5562
@assert pair.head == :(call) && pair.args[1] == :(=>)
56-
g1 = gridgraphfromstring(pair.args[2])
57-
g2 = gridgraphfromstring(pair.args[3])
63+
g1 = gridgraphfromstring(mode, pair.args[2])
64+
g2 = gridgraphfromstring(mode, pair.args[3])
5865
@assert g1.size == g2.size
5966
@assert g1.nodes[vertices_on_boundary(g1)] == g2.nodes[vertices_on_boundary(g2)]
6067
return quote
6168
struct $(esc(name)) <: SimplifyPattern end
6269
Base.size(::$(esc(name))) = $(g1.size)
6370
$UnitDiskMapping.source_locations(::$(esc(name))) = $(g1.nodes)
6471
$UnitDiskMapping.mapped_locations(::$(esc(name))) = $(g2.nodes)
65-
push!($(simplifier_ruleset), $(esc(name))())
6672
$(esc(name))
6773
end
6874
end
6975

76+
macro gg(expr)
77+
gg_func(UnWeighted(), expr)
78+
end
79+
7080
# # How to add a new simplification rule
7181
# 1. specify a gadget like the following. Use either `o` and `●` to specify a vertex,
7282
# either `.` or `⋅` to specify a placeholder.
@@ -83,5 +93,11 @@ end
8393
⋅ ● ⋅
8494
"""
8595

86-
# 2. run the script `project/createmap` to generate `mis_overhead` and other informations required
96+
# 2. add your gadget to simplifier ruleset.
97+
const simplifier_ruleset = SimplifyPattern[DanglingLeg()]
98+
# set centers (vertices with weight 1) for the weighted version
99+
source_centers(::WeightedGadget{DanglingLeg}) = [(2,2)]
100+
mapped_centers(::WeightedGadget{DanglingLeg}) = [(4,2)]
101+
102+
# 3. run the script `project/createmap` to generate `mis_overhead` and other informations required
87103
# for mapping back. (Note: will overwrite the source file `src/extracting_results.jl`)

src/weighted.jl

Lines changed: 45 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
export WeightedCell, WeightedGadget, WeightedNode
12
# TODO:
23
# * add path decomposition
34
struct WeightedCell{RT} <: AbstractCell
@@ -7,13 +8,10 @@ struct WeightedCell{RT} <: AbstractCell
78
weight::RT
89
end
910

10-
abstract type WeightedCrossPattern <: Pattern end
11-
abstract type WeightedSimplifyPattern <:Pattern end
1211
struct WeightedGadget{GT} <: Pattern
1312
gadget::GT
14-
factor::Int
1513
end
16-
const WeightedPattern = Union{WeightedCrossPattern, WeightedSimplifyPattern, WeightedGadget}
14+
const WeightedGadgetTypes = Union{WeightedGadget, RotatedGadget{<:WeightedGadget}, ReflectedGadget{<:WeightedGadget}}
1715

1816
Base.isempty(cell::WeightedCell) = !cell.occupied
1917
Base.empty(::Type{WeightedCell{RT}}) where RT = WeightedCell(false, false, false,0)
@@ -62,143 +60,76 @@ _weight2(::CopyLine{Weighted}, i, j) = WeightedNode(i, j, 2)
6260
_weight1(::CopyLine{Weighted}, i, j) = WeightedNode(i, j, 1)
6361
_cell_type(::Type{<:WeightedNode}) = WeightedCell{Int}
6462

65-
function source_graph(r::WeightedGadget)
66-
locs, g, pins = source_graph(r.gadget)
67-
_mul_weight.(locs, r.factor), g, pins
68-
end
69-
function mapped_graph(r::WeightedGadget)
70-
locs, g, pins = mapped_graph(r.gadget)
71-
_mul_weight.(locs, r.factor), g, pins
72-
end
73-
_mul_weight(node::SimpleNode, factor) = WeightedNode(node..., factor)
74-
mis_overhead(p::WeightedGadget) = 2*mis_overhead(p.gadget)
75-
76-
# new gadgets
77-
struct WeightedWTurn <: WeightedCrossPattern end
78-
# ⋅ ⋅ ⋅ ⋅
79-
# ⋅ ⋅ ◯ ●
80-
# ⋅ ● ● ⋅
81-
# ⋅ ● ⋅ ⋅
82-
83-
# ⋅ ⋅ ⋅ ⋅
84-
# ⋅ ⋅ ⋅ ●
85-
# ⋅ ⋅ ◯ ⋅
86-
# ⋅ ● ⋅ ⋅
87-
88-
struct WeightedBranchFix <: WeightedCrossPattern end
89-
# ⋅ ● ⋅ ⋅
90-
# ⋅ ● ◯ ⋅
91-
# ⋅ ● ● ⋅
92-
# ⋅ ● ⋅ ⋅
93-
94-
# ⋅ ● ⋅ ⋅
95-
# ⋅ ● ⋅ ⋅
96-
# ⋅ ◯ ⋅ ⋅
97-
# ⋅ ● ⋅ ⋅
98-
99-
struct WeightedTurn <: WeightedCrossPattern end
100-
# ⋅ ● ⋅ ⋅
101-
# ⋅ ● ⋅ ⋅
102-
# ⋅ ● ◯ ●
103-
# ⋅ ⋅ ⋅ ⋅
104-
105-
# ⋅ ● ⋅ ⋅
106-
# ⋅ ⋅ ◯ ⋅
107-
# ⋅ ⋅ ⋅ ●
108-
# ⋅ ⋅ ⋅ ⋅
109-
110-
struct WeightedBranch <: WeightedCrossPattern end
111-
# ⋅ ● ⋅ ⋅
112-
# ⋅ ● ⋅ ⋅
113-
# ⋅ ● ◯ ●
114-
# ⋅ ● ● ⋅
115-
# ⋅ ● ⋅ ⋅
116-
117-
118-
# ⋅ ● ⋅ ⋅ ?
119-
# ⋅ ⋅ ◯ ⋅
120-
# ⋅ ● ⋅ ●
121-
# ⋅ ⋅ ● ⋅
122-
# ⋅ ● ⋅ ⋅
123-
124-
struct WeightedBranchFixB <: WeightedCrossPattern end
125-
# ⋅ ⋅ ⋅ ⋅
126-
# ⋅ ⋅ ◯ ⋅
127-
# ⋅ ● ● ⋅
128-
# ⋅ ● ⋅ ⋅
129-
130-
# ⋅ ⋅ ⋅ ⋅
131-
# ⋅ ⋅ ⋅ ⋅
132-
# ⋅ ◯ ⋅ ⋅
133-
# ⋅ ● ⋅ ⋅
134-
135-
struct WeightedEndTurn <: WeightedCrossPattern end
136-
# ⋅ ● ⋅ ⋅
137-
# ⋅ ● ◯ ⋅
138-
# ⋅ ⋅ ⋅ ⋅
139-
140-
# ⋅ ◯ ⋅ ⋅
141-
# ⋅ ⋅ ⋅ ⋅
142-
# ⋅ ⋅ ⋅ ⋅
143-
144-
for T in [:Cross, :TrivialTurn, :TCon]
145-
@eval weighted(c::$T) = WeightedGadget(c, 2)
146-
end
63+
weighted(c::Pattern) = WeightedGadget(c)
14764
unweighted(w::WeightedGadget) = w.gadget
14865
weighted(r::RotatedGadget) = RotatedGadget(weighted(r.gadget), r.n)
14966
weighted(r::ReflectedGadget) = ReflectedGadget(weighted(r.gadget), r.mirror)
15067
unweighted(r::RotatedGadget) = RotatedGadget(unweighted(r.gadget), r.n)
15168
unweighted(r::ReflectedGadget) = ReflectedGadget(unweighted(r.gadget), r.mirror)
69+
mis_overhead(w::WeightedGadget) = mis_overhead(w.gadget) * 2
15270

153-
for T in [:Turn, :Branch, :BranchFix, :BranchFixB, :WTurn, :EndTurn]
154-
WT = Symbol(:Weighted, T)
155-
@eval weighted(::$T) = $WT()
156-
@eval unweighted(::$WT) = $T()
157-
@eval mis_overhead(::$WT) = mis_overhead($T()) * 2
71+
function source_graph(r::WeightedGadget)
72+
raw = unweighted(r)
73+
locs, g, pins = source_graph(raw)
74+
return map(loc->_mul_weight(loc, getxy(loc) source_centers(r) ? 1 : 2), locs), g, pins
75+
end
76+
function mapped_graph(r::WeightedGadget)
77+
raw = unweighted(r)
78+
locs, g, pins = mapped_graph(raw)
79+
return map(loc->_mul_weight(loc, getxy(loc) mapped_centers(r) ? 1 : 2), locs), g, pins
15880
end
81+
_mul_weight(node::SimpleNode, factor) = WeightedNode(node..., factor)
15982

16083
for (T, centerloc) in [(:Turn, (2, 3)), (:Branch, (2, 3)), (:BranchFix, (3, 2)), (:BranchFixB, (3, 2)), (:WTurn, (3, 3)), (:EndTurn, (1, 2))]
161-
WT = Symbol(:Weighted, T)
162-
@eval function source_graph(r::$WT)
163-
raw = unweighted(r)
164-
locs, g, pins = source_graph(raw)
165-
return map(loc->_mul_weight(loc, loc == SimpleNode(cross_location(raw) .+ (0, 1)) ? 1 : 2), locs), g, pins
166-
end
167-
@eval function mapped_graph(r::$WT)
168-
raw = unweighted(r)
169-
locs, g, pins = mapped_graph(raw)
170-
return map(loc->_mul_weight(loc, loc == SimpleNode($centerloc) ? 1 : 2), locs), g, pins
84+
@eval source_centers(::WeightedGadget{<:$T}) = [cross_location($T()) .+ (0, 1)]
85+
@eval mapped_centers(::WeightedGadget{<:$T}) = [$centerloc]
86+
end
87+
# default to having no source center!
88+
source_centers(::WeightedGadget) = Tuple{Int,Int}[]
89+
mapped_centers(::WeightedGadget) = Tuple{Int,Int}[]
90+
for T in [:(RotatedGadget{<:WeightedGadget}), :(ReflectedGadget{<:WeightedGadget})]
91+
@eval function source_centers(r::$T)
92+
cross = cross_location(r.gadget)
93+
return map(loc->loc .+ _get_offset(r), _apply_transform.(Ref(r), source_centers(r.gadget), Ref(cross)))
17194
end
172-
@eval function move_center(::$WT, nodexy)
173-
nodexy .+ $centerloc .- (cross_location($WT()) .+ (0, 1))
95+
@eval function mapped_centers(r::$T)
96+
cross = cross_location(r.gadget)
97+
return map(loc->loc .+ _get_offset(r), _apply_transform.(Ref(r), mapped_centers(r.gadget), Ref(cross)))
17498
end
17599
end
176-
move_center(::Pattern, nodexy) = nodexy
177100

178-
for T in [:WeightedCrossPattern, :WeightedGadget]
179-
@eval Base.size(r::$T) = size(unweighted(r))
180-
@eval cross_location(r::$T) = cross_location(unweighted(r))
181-
@eval iscon(r::$T) = iscon(unweighted(r))
182-
@eval connected_nodes(r::$T) = connected_nodes(unweighted(r))
183-
@eval vertex_overhead(r::$T) = vertex_overhead(unweighted(r))
184-
end
101+
Base.size(r::WeightedGadget) = size(unweighted(r))
102+
cross_location(r::WeightedGadget) = cross_location(unweighted(r))
103+
iscon(r::WeightedGadget) = iscon(unweighted(r))
104+
connected_nodes(r::WeightedGadget) = connected_nodes(unweighted(r))
105+
vertex_overhead(r::WeightedGadget) = vertex_overhead(unweighted(r))
185106

186107
const crossing_ruleset_weighted = weighted.(crossing_ruleset)
187108
get_ruleset(::Weighted) = crossing_ruleset_weighted
188109

189110
export get_weights
190111
get_weights(ug::UGrid) = [ug.content[ci...].weight for ci in coordinates(ug)]
191112

113+
# mapping configurations back
192114
export trace_centers
115+
function move_center(w::WeightedGadgetTypes, nodexy, offset)
116+
for (sc, mc) in zip(source_centers(w), mapped_centers(w))
117+
if offset == sc
118+
return nodexy .+ mc .- sc # found
119+
end
120+
end
121+
error("center not found, source center = $(source_centers(w)), while offset = $(offset)")
122+
end
123+
193124
trace_centers(r::MappingResult) = trace_centers(r.grid_graph, r.mapping_history)
194125
function trace_centers(ug::UGrid, tape)
195126
center_locations = map(x->center_location(x; padding=ug.padding) .+ (0, 1), ug.lines)
196127
for (gadget, i, j) in tape
197128
m, n = size(gadget)
198129
for (k, centerloc) in enumerate(center_locations)
199-
offset = centerloc .- (i,j)
200-
if 0<=offset[1] <= m-1 && 0<=offset[2] <= n-1
201-
center_locations[k] = move_center(gadget, centerloc)
130+
offset = centerloc .- (i-1,j-1)
131+
if 1<=offset[1] <= m && 1<=offset[2] <= n
132+
center_locations[k] = move_center(gadget, centerloc, offset)
202133
end
203134
end
204135
end

test/gadgets.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,11 @@ using Graphs
1616
@test diff == -mis_overhead(s)
1717
@test sig
1818
end
19+
end
20+
21+
@testset "rotated_and_reflected" begin
22+
@test length(rotated_and_reflected(UnitDiskMapping.DanglingLeg())) == 4
23+
@test length(rotated_and_reflected(Cross{false}())) == 4
24+
@test length(rotated_and_reflected(Cross{true}())) == 4
25+
@test length(rotated_and_reflected(BranchFixB())) == 8
1926
end

0 commit comments

Comments
 (0)