From a6bb92992f7df0c6d4725db5f94e1b5d16478137 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 1 Sep 2025 15:09:00 +0000 Subject: [PATCH 1/2] Make it easier to see when redistributions happen in get_log --- autoparallel/optimize_sharding.py | 42 +++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 49f6d020..5c475ee2 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -628,16 +628,45 @@ def get_log(self, colored=False): attr_color = _identity l_id = 1 plc_txt = txt_color("# placement=") - cost_txt = txt_color(", cost=") + cost_txt = txt_color(", compute_cost=") for node in nodes: if node.op == "output": continue d = opt[node] + d_inputs = [opt[n] for n in self._all_input_nodes(node)] + redistributions = [] + for i, x in enumerate(d_inputs): + src_strat = x[0]["out_strat"] + if not isinstance(src_strat, DTensorSpec): + src_strat = src_strat[i] + dst_strat = d[i]["inp_strat"] + if src_strat != dst_strat: + redistributions.append( + ( + i, + self._all_input_nodes(node)[i], + src_strat, + dst_strat, + ) + ) + # this is an annoying special case + if node.target == operator.getitem: + idx = node.args[1] + src_strat = d_inputs[0][0]["out_strat"][idx] + dst_strat = d[0]["inp_strat"] + redistributions = [] + # redistributions.append((0, node.args[0], src_strat, dst_strat)) + preline = "" + if redistributions: + preline += f"\n # Redistributing for node {str(node)}:" + for i, n, src, dst in redistributions: + src_s = "".join(str(p) for p in src.placements) + dst_s = "".join(str(p) for p in dst.placements) + comment = f"# shape={str(tuple(src.shape))}, comm_cost={d[i]['comm_cost']}" + preline += f"\n # {str(n)} = redistribute({str(n)}, src={src_s}, dst={dst_s}) {comment}" strat = str(d[0]["full_strat"]) - costs = [ - (x["comm_cost"], x["compute_cost"], x["sharding_transition_cost"]) - for x in d - ] + costs = [x["compute_cost"] for x in d] + costs = sum(costs) device_order = getattr(node.meta, "device_order", None) if device_order: line = f" {plc_txt}{attr_color(strat)} device_order: {device_order} {cost_txt}{attr_color(str(costs))}" @@ -654,6 +683,9 @@ def get_log(self, colored=False): while not code[l_id].lstrip().startswith(repr(node)): l_id += 1 code[l_id] += line + if preline: + code.insert(l_id, preline) + l_id += 1 l_id += 1 code = "\n".join(code) total_cost = sum(self.ds[x]["cost"] for x in self.res) From d5fd476c2b2f3d046036a80b845d0f1b44478cd5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 1 Sep 2025 15:12:44 +0000 Subject: [PATCH 2/2] Minor cleanups --- autoparallel/optimize_sharding.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 5c475ee2..045b8dff 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -651,11 +651,7 @@ def get_log(self, colored=False): ) # this is an annoying special case if node.target == operator.getitem: - idx = node.args[1] - src_strat = d_inputs[0][0]["out_strat"][idx] - dst_strat = d[0]["inp_strat"] redistributions = [] - # redistributions.append((0, node.args[0], src_strat, dst_strat)) preline = "" if redistributions: preline += f"\n # Redistributing for node {str(node)}:" @@ -665,15 +661,12 @@ def get_log(self, colored=False): comment = f"# shape={str(tuple(src.shape))}, comm_cost={d[i]['comm_cost']}" preline += f"\n # {str(n)} = redistribute({str(n)}, src={src_s}, dst={dst_s}) {comment}" strat = str(d[0]["full_strat"]) - costs = [x["compute_cost"] for x in d] - costs = sum(costs) + costs = sum(x["compute_cost"] for x in d) device_order = getattr(node.meta, "device_order", None) + device_order_str = "" if device_order: - line = f" {plc_txt}{attr_color(strat)} device_order: {device_order} {cost_txt}{attr_color(str(costs))}" - else: - line = ( - f" {plc_txt}{attr_color(strat)} {cost_txt}{attr_color(str(costs))}" - ) + device_order_str = f" device_order={device_order} " + line = f" {plc_txt}{attr_color(strat)}{device_order_str} {cost_txt}{attr_color(str(costs))}" if node.op == "placeholder": line = f" # {node.name}: {line}" code.insert(l_id, line)