Skip to content

Commit c752854

Browse files
authored
Add sharding_transition_cost to getitem operators (#132)
Also update the OpSpec redistribution cost with the newly computed communication cost, so that it is easier to debug with print_costs_for_node
1 parent 006fe59 commit c752854

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

autoparallel/optimize_sharding.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -310,21 +310,6 @@ def build_ds(self):
310310
self.strats[all_input_nodes[argi]] if all_input_nodes else None
311311
)
312312
for ii, comm_cost in enumerate(xxi):
313-
if argi_strat is not None:
314-
src_spec = argi_strat.strategies[ii].output_specs
315-
# TODO: operator.getitem being special is something
316-
# we might want to change in the future
317-
if node.target == operator.getitem:
318-
src_spec = src_spec[node.args[1]]
319-
tgt_spec = ssi.input_specs[argi]
320-
assert isinstance(src_spec, DTensorSpec)
321-
assert isinstance(tgt_spec, DTensorSpec)
322-
# we use our custom comm_cost function to estimate the cost
323-
# of the collective operation
324-
comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)
325-
326-
if node in grad_param_nodes:
327-
comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
328313
# Imagine we start node_i from S(0)S(0) and we want to reach node_{i+2} at
329314
# RR, and that node_{i+1} is an op with zero cost (like alias).
330315
# In this case, all of the following chains yield the same cost:
@@ -338,19 +323,34 @@ def build_ds(self):
338323
# in a single go. To do this, we add a tie-break cost that is 1 if a redistribution
339324
# happens prior to getting to this configuration, and 0 otherwise. This way,
340325
# we will favor having fewer redistributions happening in the graph.
341-
if argi_strat is not None and node.target != operator.getitem:
342-
original_placement = argi_strat.strategies[
343-
ii
344-
].output_specs.placements
345-
current_placement = ssi.input_specs[argi].placements
326+
if argi_strat is not None:
327+
src_spec = argi_strat.strategies[ii].output_specs
328+
# TODO: operator.getitem being special is something
329+
# we might want to change in the future
330+
if node.target == operator.getitem:
331+
src_spec = src_spec[node.args[1]]
332+
tgt_spec = ssi.input_specs[argi]
333+
assert isinstance(src_spec, DTensorSpec)
334+
assert isinstance(tgt_spec, DTensorSpec)
335+
# we use our custom comm_cost function to estimate the cost
336+
# of the collective operation
337+
comm_cost = estimate_strategy_comms_cost(src_spec, tgt_spec)
338+
346339
redistribution_happened = (
347-
current_placement != original_placement
340+
src_spec.placements != tgt_spec.placements
348341
)
349342
sharding_transition_cost = (
350343
int(redistribution_happened) * sharding_transition_scale
351344
)
352345
else:
353346
sharding_transition_cost = 0
347+
348+
if node in grad_param_nodes:
349+
comm_cost = comm_cost / self.rescale_grad_comm_cost_for_mp
350+
351+
# update OpSpec redistribution cost with our newly-computed cost
352+
# this is useful for print_costs_for_node to print the updated cost
353+
xxi[ii] = comm_cost
354354
key = (s_i, argi, ss, ii)
355355
# NOTE: this modifies ds in-place sometimes
356356
# we might want to refactor this in the future

0 commit comments

Comments
 (0)