
Commit 006fe59

Fix clustering key to account for input node placements as well (#140)
Previously, we were only hashing the current node's OpSpec string. That string includes the node's supported input/output placements, but doesn't account for its input nodes, which might have different output shardings. This PR fixes that, and also adds an assert in optimize_sharding to validate that the created links are consistent.
1 parent 986c922 commit 006fe59
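
To see why the old key collided, here is a minimal, self-contained sketch (the Node class, strategy strings, and placement strings below are hypothetical stand-ins for the real fx.Node / OpStrategy objects; only the structure of the hash key mirrors the patch):

from hashlib import sha256
import pickle

class Node:
    # Hypothetical stand-in for torch.fx.Node; only all_input_nodes is used.
    def __init__(self, name, all_input_nodes=()):
        self.name = name
        self.all_input_nodes = list(all_input_nodes)

def old_key(node, strategies):
    # Old behavior: hash only the node's own strategy string.
    return sha256(pickle.dumps(strategies[node])).hexdigest()

def new_key(node, strategies, output_placements):
    # New behavior: also hash the output placements of every input node.
    key = (
        strategies[node],
        tuple(output_placements[inp] for inp in node.all_input_nodes),
    )
    return sha256(pickle.dumps(key)).hexdigest()

a, b = Node("a"), Node("b")
view1, view2 = Node("view1", [a]), Node("view2", [b])
strategies = {view1: "view: (S(0)) -> (S(0))", view2: "view: (S(0)) -> (S(0))"}
output_placements = {a: "(S(0))", b: "(R)"}  # inputs sharded differently

assert old_key(view1, strategies) == old_key(view2, strategies)  # spurious match
assert new_key(view1, strategies, output_placements) != new_key(
    view2, strategies, output_placements
)  # distinct, so the two views are no longer clustered together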

File tree

2 files changed: +37 -9 lines

autoparallel/graph_clustering.py

Lines changed: 28 additions & 6 deletions
@@ -26,6 +26,7 @@
     tree_flatten,
 )
 from torch._inductor.codecache import sha256_hash
+from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._op_schema import OpStrategy

 logger: logging.Logger = logging.getLogger(__name__)
@@ -52,7 +53,24 @@ def _normalize_args(
     return (sorted_keys, tuple(_extract_args(arg) for arg in all_args))


-def _prepare_op_strategy(op_strategy):
+def _print_output_specs(op_strategy):
+    output = []
+    for s in op_strategy.strategies:
+        output_placements = []
+        output_specs = s.output_specs
+        if isinstance(output_specs, DTensorSpec):
+            output_specs = [output_specs]
+        for output_spec in output_specs:
+            if output_spec is None:
+                output_placements.append("(None)")
+                continue
+            plc_str = ",".join([str(p) for p in output_spec.placements])
+            output_placements.append(f"({plc_str})")
+        output.append(f"({','.join(output_placements)})")
+    return ", ".join(output)
+
+
+def _prepare_op_strategy(op_strategy, output_only=False):
     # hashing op_strategy is expensive, so we hash the string representation
     # instead, which is much cheaper and is a reasonable proxy for the
     # clustering
@@ -62,14 +80,20 @@ def _prepare_op_strategy(op_strategy):
     # view ops, which propagate the input shardings to the output.
     # So we also add the strategy for a node as a hash key to avoid
     # clustering nodes that look the same but have different strategies
+    if output_only:
+        return _print_output_specs(op_strategy)
     return str(op_strategy)


-def _hash_node(node, op_strategy, input_pickler):
+def _hash_node(node, strategies, input_pickler):
     key = (
         node.meta.get("stack_trace"),
         _normalize_args(node),
-        _prepare_op_strategy(op_strategy),
+        _prepare_op_strategy(strategies[node]),
+        tuple(
+            _prepare_op_strategy(strategies[s], output_only=True)
+            for s in node.all_input_nodes
+        ),
     )
     return sha256_hash(input_pickler.dumps(key))

@@ -104,9 +128,7 @@ def get_identical_regions(
         if node.op == "placeholder":
             continue

-        duplicates = hash_to_duplicates[
-            _hash_node(node, strategies[node], input_pickler)
-        ]
+        duplicates = hash_to_duplicates[_hash_node(node, strategies, input_pickler)]
         duplicates.append(node)
         node_to_duplicates[node] = duplicates
     logger.info(f"Hashed nodes in {time.time() - t} s")
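
For reference, a duck-typed toy showing the compact string _print_output_specs builds (FakeSpec and FakeStrategy are hypothetical stand-ins; the real function additionally unwraps a bare DTensorSpec into a one-element list):

from dataclasses import dataclass

@dataclass
class FakeSpec:
    placements: tuple  # e.g. ("S(0)", "R"), standing in for real Placements

@dataclass
class FakeStrategy:
    output_specs: list  # list of FakeSpec or None, one per output

def print_output_specs(strategies):
    # Mirrors the formatting logic: one "(...)" group per strategy, one
    # "(...)" group per output inside it.
    output = []
    for s in strategies:
        parts = []
        for spec in s.output_specs:
            if spec is None:
                parts.append("(None)")
            else:
                parts.append("(" + ",".join(str(p) for p in spec.placements) + ")")
        output.append(f"({','.join(parts)})")
    return ", ".join(output)

strats = [
    FakeStrategy([FakeSpec(("S(0)",)), None]),
    FakeStrategy([FakeSpec(("R", "S(1)"))]),
]
print(print_output_specs(strats))  # ((S(0)),(None)), ((R,S(1)))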

autoparallel/optimize_sharding.py

Lines changed: 9 additions & 3 deletions
@@ -218,7 +218,13 @@ def create_cluster_links(self, clusters):
         for n0, ni in zip(cluster0, cluster_i):
             s0 = self.node_map[n0]
             s1 = self.node_map[ni]
-            for argi, oi, ii in self.walk_over_options(n0):
+            options_n0 = list(self.walk_over_options(n0))
+            options_ni = list(self.walk_over_options(ni))
+            assert options_n0 == options_ni, (
+                f"Problem with graph clustering: {n0} and {ni} don't have the same number "
+                "of input/output placements. Please report a bug"
+            )
+            for argi, oi, ii in options_n0:
                 self.cluster_links[(s1, argi, oi, ii)] = (s0, argi, oi, ii)

     def _build_pulp_variable(self, key, ds):
@@ -475,7 +481,7 @@ def add_output_input_consistent_constraint(self):
                 va = self.ds[key]["va"]
                 vars_s_j.setdefault(s_j_ii, []).append(va)

-        if vars_s_i.keys() != vars_s_j.keys():
+        if len(vars_s_j) == 0:
            vars_s_j = {}
            for _, s_j_oi, s_j_ii in self.walk_over_options(user, argj):
                key = (s_j, argj, s_j_oi, s_j_ii)
@@ -485,7 +491,7 @@ def add_output_input_consistent_constraint(self):
                va = self.ds[key]["va"]
                vars_s_j.setdefault(s_j_ii, []).append(va)

-        if vars_s_i.keys() != vars_s_j.keys():
+        if len(vars_s_i) == 0:
            vars_s_i = {}
            for _, s_i_oi, s_i_ii in self.walk_over_options(node, argi):
                key = (s_i, argi, s_i_oi, s_i_ii)
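
Finally, a toy illustration of what the new assert in create_cluster_links guarantees, assuming walk_over_options yields (arg_index, output_option, input_option) triples (the helper and placeholder strings below are hypothetical; only the link-table shape mirrors the patch):

def check_cluster_pair(options_n0, options_ni):
    # Clustered nodes must enumerate identical option triples; otherwise a
    # link (s1, argi, oi, ii) -> (s0, argi, oi, ii) could point at an option
    # the reference node never had.
    assert options_n0 == options_ni, (
        "Problem with graph clustering: nodes don't have the same "
        "input/output placements. Please report a bug"
    )
    return {("s1", *opt): ("s0", *opt) for opt in options_n0}

links = check_cluster_pair(
    [(0, 0, 0), (0, 0, 1), (0, 1, 0)],
    [(0, 0, 0), (0, 0, 1), (0, 1, 0)],
)
print(links[("s1", 0, 1, 0)])  # ('s0', 0, 1, 0)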
