@@ -310,21 +310,6 @@ def build_ds(self):
310310 self .strats [all_input_nodes [argi ]] if all_input_nodes else None
311311 )
312312 for ii , comm_cost in enumerate (xxi ):
313- if argi_strat is not None :
314- src_spec = argi_strat .strategies [ii ].output_specs
315- # TODO: operator.getitem being special is something
316- # we might want to change in the future
317- if node .target == operator .getitem :
318- src_spec = src_spec [node .args [1 ]]
319- tgt_spec = ssi .input_specs [argi ]
320- assert isinstance (src_spec , DTensorSpec )
321- assert isinstance (tgt_spec , DTensorSpec )
322- # we use our custom comm_cost function to estimate the cost
323- # of the collective operation
324- comm_cost = estimate_strategy_comms_cost (src_spec , tgt_spec )
325-
326- if node in grad_param_nodes :
327- comm_cost = comm_cost / self .rescale_grad_comm_cost_for_mp
328313 # Imagine we start node_i from S(0)S(0) and we want to reach node_{i+2} at
329314 # RR, and that node_{i+1} is an op with zero cost (like alias).
330315 # In this case, all of the following chains yield the same cost:
@@ -338,19 +323,34 @@ def build_ds(self):
338323 # in a single go. To do this, we add a tie-break cost that is 1 if a redistribution
339324 # happens prior to getting to this configuration, and 0 otherwise. This way,
340325 # we will favor having fewer redistributions happening in the graph.
341- if argi_strat is not None and node .target != operator .getitem :
342- original_placement = argi_strat .strategies [
343- ii
344- ].output_specs .placements
345- current_placement = ssi .input_specs [argi ].placements
326+ if argi_strat is not None :
327+ src_spec = argi_strat .strategies [ii ].output_specs
328+ # TODO: operator.getitem being special is something
329+ # we might want to change in the future
330+ if node .target == operator .getitem :
331+ src_spec = src_spec [node .args [1 ]]
332+ tgt_spec = ssi .input_specs [argi ]
333+ assert isinstance (src_spec , DTensorSpec )
334+ assert isinstance (tgt_spec , DTensorSpec )
335+ # we use our custom comm_cost function to estimate the cost
336+ # of the collective operation
337+ comm_cost = estimate_strategy_comms_cost (src_spec , tgt_spec )
338+
346339 redistribution_happened = (
347- current_placement != original_placement
340+ src_spec . placements != tgt_spec . placements
348341 )
349342 sharding_transition_cost = (
350343 int (redistribution_happened ) * sharding_transition_scale
351344 )
352345 else :
353346 sharding_transition_cost = 0
347+
348+ if node in grad_param_nodes :
349+ comm_cost = comm_cost / self .rescale_grad_comm_cost_for_mp
350+
351+ # update OpSpec redistribution cost with our newly-computed cost
352+ # this is useful for print_costs_for_node to print the updated cost
353+ xxi [ii ] = comm_cost
354354 key = (s_i , argi , ss , ii )
355355 # NOTE: this modifies ds in-place sometimes
356356 # we might want to refactor this in the future
0 commit comments