@@ -745,14 +745,16 @@ def get_solution(self, verbose=False):
745745 # TODO: assert all nodes have a placement?
746746 return opt
747747
748- def _add_node_constraint (self , node , oi , constraint_name = None ):
748+ def _add_node_constraint (
749+ self , node , output_constraint_indices , constraint_name = None
750+ ):
749751 if constraint_name is None :
750752 constraint_name = "user_constraint"
751753 s_i = self .node_map [node ]
752754 vars_per_arg = {}
753- for argi , oi_ , ii in self .walk_over_options (node ):
754- if oi_ == oi :
755- va = self .ds [(s_i , argi , oi , ii )]["va" ]
755+ for argi , output_constraint_index , input_index in self .walk_over_options (node ):
756+ if output_constraint_index in output_constraint_indices :
757+ va = self .ds [(s_i , argi , output_constraint_index , input_index )]["va" ]
756758 vars_per_arg .setdefault (argi , []).append (va )
757759 for eqs in vars_per_arg .values ():
758760 self .prob += (pulp .lpSum (eqs ) == 1 , _get_next_name (constraint_name ))
@@ -834,15 +836,20 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
834836 if placement is None :
835837 # default is Shard(0) to parallelize on the batch
836838 placement = (Shard (0 ),) + (Replicate (),) * (self .mesh .ndim - 1 )
837- for oi , s in enumerate (strat .strategies ):
839+ output_constraint_indices = []
840+ for output_constraint_index , s in enumerate (strat .strategies ):
838841 spec = s .output_specs
839842 if spec .placements == placement :
840- break
841- else :
843+ output_constraint_indices . append ( output_constraint_index )
844+ if len ( output_constraint_indices ) == 0 :
842845 raise RuntimeError (
843846 f"Couldn't find appropriate constraint { node } { constraint_name } { placement } "
844847 )
845- self ._add_node_constraint (node , oi = oi , constraint_name = constraint_name )
848+ self ._add_node_constraint (
849+ node ,
850+ output_constraint_indices = output_constraint_indices ,
851+ constraint_name = constraint_name ,
852+ )
846853
847854 def add_sharded_input_constraint (
848855 self , input_placements : Optional [list [Optional [tuple [Placement , ...]]]] = None
0 commit comments