Skip to content

Commit 90332a3

Browse files
authored
Don't stop at first OpSpec that satisfies output constraint (#131)
* Don't take first OpSpec that satisfies output placement * Cleanup variable names
1 parent c752854 commit 90332a3

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

autoparallel/optimize_sharding.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -745,14 +745,16 @@ def get_solution(self, verbose=False):
745745
# TODO: assert all nodes have a placement?
746746
return opt
747747

748-
def _add_node_constraint(self, node, oi, constraint_name=None):
748+
def _add_node_constraint(
749+
self, node, output_constraint_indices, constraint_name=None
750+
):
749751
if constraint_name is None:
750752
constraint_name = "user_constraint"
751753
s_i = self.node_map[node]
752754
vars_per_arg = {}
753-
for argi, oi_, ii in self.walk_over_options(node):
754-
if oi_ == oi:
755-
va = self.ds[(s_i, argi, oi, ii)]["va"]
755+
for argi, output_constraint_index, input_index in self.walk_over_options(node):
756+
if output_constraint_index in output_constraint_indices:
757+
va = self.ds[(s_i, argi, output_constraint_index, input_index)]["va"]
756758
vars_per_arg.setdefault(argi, []).append(va)
757759
for eqs in vars_per_arg.values():
758760
self.prob += (pulp.lpSum(eqs) == 1, _get_next_name(constraint_name))
@@ -834,15 +836,20 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
834836
if placement is None:
835837
# default is Shard(0) to parallelize on the batch
836838
placement = (Shard(0),) + (Replicate(),) * (self.mesh.ndim - 1)
837-
for oi, s in enumerate(strat.strategies):
839+
output_constraint_indices = []
840+
for output_constraint_index, s in enumerate(strat.strategies):
838841
spec = s.output_specs
839842
if spec.placements == placement:
840-
break
841-
else:
843+
output_constraint_indices.append(output_constraint_index)
844+
if len(output_constraint_indices) == 0:
842845
raise RuntimeError(
843846
f"Couldn't find appropriate constraint {node} {constraint_name} {placement}"
844847
)
845-
self._add_node_constraint(node, oi=oi, constraint_name=constraint_name)
848+
self._add_node_constraint(
849+
node,
850+
output_constraint_indices=output_constraint_indices,
851+
constraint_name=constraint_name,
852+
)
846853

847854
def add_sharded_input_constraint(
848855
self, input_placements: Optional[list[Optional[tuple[Placement, ...]]]] = None

0 commit comments

Comments
 (0)