diff --git a/cadquery/assembly.py b/cadquery/assembly.py index 1aac2cb3d..1c960c11e 100644 --- a/cadquery/assembly.py +++ b/cadquery/assembly.py @@ -14,6 +14,8 @@ from typing_extensions import Literal from typish import instance_of from uuid import uuid1 as uuid +from enum import Enum +from collections import defaultdict from .cq import Workplane from .occ_impl.shapes import Shape, Compound @@ -44,6 +46,7 @@ PATH_DELIM = "/" + # entity selector grammar definition def _define_grammar(): @@ -79,6 +82,133 @@ def _define_grammar(): _grammar = _define_grammar() +class _Quantity(Enum): + POINT = 1 + AXIS = 2 + + def __repr__(self): + if self == _Quantity.POINT: + return "p" + else: + return "a" + + +class ConstraintGraph: + """ + Auxiliary structure for tracking constraint relations. + + We store an undirected graph of (object id, quantity) nodes. Every binary + constraint with 0 degrees of freedom links two property nodes. We then + traverse through the graph to solve each constraint to its local minimum. + """ + + def __init__(self, n_objs: int): + self.n_objs = n_objs + self.adjlist = defaultdict(lambda: defaultdict(dict)) + # Index of fixed objects + self.locked = set() + # Map from ordered tuples to constraints for fast solving + self.constraints = dict() + + def add_unary(self, i, pod): + """ + Unary constraint + """ + srcs, kind, param = pod + if kind == "Fixed": + self.locked.add((i, _Quantity.POINT)) + self.locked.add((i, _Quantity.AXIS)) + elif kind == "FixedPoint": + (src,) = srcs + src.SetCoord(*param) + self.locked.add((i, _Quantity.POINT)) + elif kind == "FixedAxis": + (src,) = srcs + src.SetCoord(*param) + self.locked.add((i, _Quantity.AXIS)) + + def add_binary(self, i, j, pod): + """ + Binary constraint + """ + (src, dst), kind, param = pod + + q = None + # Only keep the constraints with 0 degrees of freedom + assert kind != "Plane", "Compound constraints should not exist here" + if kind in {"PointInPlane", "PointOnLine"}: + return + elif kind == "Axis": + if param and float(param) not in {0.0, 180.0}: + return + q = _Quantity.AXIS + elif kind == "Point": + if param and param != 0.0: + return + q = _Quantity.POINT + else: + raise ValueError(f"Unknown kind {kind}") + + # Insert into adjacency lists + self.adjlist[(i, q)][(j, q)] = pod + self.adjlist[(j, q)][(i, q)] = ((dst, src), kind, param) + + def __str__(self): + def format_i(i): + return f"!{i}" if i in self.locked else str(i) + + adjlist = [ + f"{format_i(src)} -> {[format_i(j) for j in dst]}" + for src, dst in self.adjlist.items() + ] + return "\n".join(adjlist) + + @staticmethod + def zero_point_binary_constraint(pod): + (src, dst), kind, param = pod + if kind == "Axis": + if param and float(param) == 180.0: + dst.SetCoord(*(-src).Coord()) + else: + dst.SetCoord(*src.Coord()) + elif kind == "Point": + dst.SetCoord(*src.Coord()) + else: + raise ValueError(f"Illegal binary constraint {kind}") + + def solve(self, verbosity: int = 0): + remaining = {(i, key) for i, key in self.adjlist if (i, key) not in self.locked} + + def start_solve(node): + li = [node] + while li: + (i, qty) = li.pop() + if (i, qty) in remaining: + remaining.remove((i, qty)) + # Query the adjlist + for (j, p), pod in self.adjlist.get((i, qty), {}).items(): + if (j, p) in remaining: + li.append((j, p)) + else: + continue + + if verbosity > 3: + print(f"Solving {(i, qty)} -> {(j, p)} with kind {pod[1]}") + # Zero the constraint (i,qty) to (j, p) + ConstraintGraph.zero_point_binary_constraint(pod) + + # Start solving from each of the locked nodes + if verbosity > 3: + print(f"Start solving from locked: {len(self.locked)}") + for node in self.locked: + start_solve(node) + if verbosity > 3: + print(f"Remaining points: {len(remaining)}") + while remaining: + node = remaining.pop() + start_solve(node) + + class Assembly(object): """Nested assembly of Workplane and Shape objects defining their relative positions.""" @@ -346,7 +476,9 @@ def constrain( ... @overload - def constrain(self, q1: str, kind: ConstraintKind, param: Any = None) -> "Assembly": + def constrain( + self, q1: str, kind: ConstraintKind, param: Any = None + ) -> "Assembly": ... @overload @@ -363,7 +495,11 @@ def constrain( @overload def constrain( - self, id1: str, s1: Shape, kind: ConstraintKind, param: Any = None, + self, + id1: str, + s1: Shape, + kind: ConstraintKind, + param: Any = None, ) -> "Assembly": ... @@ -409,9 +545,12 @@ def constrain(self, *args, param=None): return self - def solve(self, verbosity: int = 0) -> "Assembly": + def solve(self, verbosity: int = 0, tree_initialize: bool = True) -> "Assembly": """ Solve the constraints. + + Set `tree_initialize` to true to set the constraints by tree + exploration. """ # Get all entities and number them. First entity is marked as locked @@ -453,6 +592,8 @@ def solve(self, verbosity: int = 0) -> "Assembly": locs = [self.objects[n].loc for n in ents] + cgraph = ConstraintGraph(len(ents)) + # construct the constraint mapping constraints = [] for c in self.constraints: @@ -462,6 +603,13 @@ def solve(self, verbosity: int = 0) -> "Assembly": for pod in pods: constraints.append((ixs, pod)) + if tree_initialize: + if len(ixs) == 1: + cgraph.add_unary(ixs[0], pod) + elif len(ixs) == 2: + i, j = ixs + cgraph.add_binary(i, j, pod) + # check if any constraints were specified if not constraints: raise ValueError("At least one constraint required") @@ -470,6 +618,14 @@ def solve(self, verbosity: int = 0) -> "Assembly": if len(ents) < 2: raise ValueError("At least two entities need to be constrained") + if tree_initialize: + """ + Set the locations + """ + if verbosity > 3: + print(cgraph) + cgraph.solve(verbosity) + # instantiate the solver scale = self.toCompound().BoundingBox().DiagonalLength solver = ConstraintSolver(locs, constraints, locked=locked, scale=scale) @@ -668,8 +824,12 @@ def __iter__( color = self.color if self.color else color if self.obj: - yield self.obj if isinstance(self.obj, Shape) else Compound.makeCompound( - s for s in self.obj.vals() if isinstance(s, Shape) + yield ( + self.obj + if isinstance(self.obj, Shape) + else Compound.makeCompound( + s for s in self.obj.vals() if isinstance(s, Shape) + ) ), name, loc, color for ch in self.children: