diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 49f6d020..045b8dff 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -628,23 +628,45 @@ def get_log(self, colored=False): attr_color = _identity l_id = 1 plc_txt = txt_color("# placement=") - cost_txt = txt_color(", cost=") + cost_txt = txt_color(", compute_cost=") for node in nodes: if node.op == "output": continue d = opt[node] + d_inputs = [opt[n] for n in self._all_input_nodes(node)] + redistributions = [] + for i, x in enumerate(d_inputs): + src_strat = x[0]["out_strat"] + if not isinstance(src_strat, DTensorSpec): + src_strat = src_strat[i] + dst_strat = d[i]["inp_strat"] + if src_strat != dst_strat: + redistributions.append( + ( + i, + self._all_input_nodes(node)[i], + src_strat, + dst_strat, + ) + ) + # this is an annoying special case + if node.target == operator.getitem: + redistributions = [] + preline = "" + if redistributions: + preline += f"\n # Redistributing for node {str(node)}:" + for i, n, src, dst in redistributions: + src_s = "".join(str(p) for p in src.placements) + dst_s = "".join(str(p) for p in dst.placements) + comment = f"# shape={str(tuple(src.shape))}, comm_cost={d[i]['comm_cost']}" + preline += f"\n # {str(n)} = redistribute({str(n)}, src={src_s}, dst={dst_s}) {comment}" strat = str(d[0]["full_strat"]) - costs = [ - (x["comm_cost"], x["compute_cost"], x["sharding_transition_cost"]) - for x in d - ] + costs = sum(x["compute_cost"] for x in d) device_order = getattr(node.meta, "device_order", None) + device_order_str = "" if device_order: - line = f" {plc_txt}{attr_color(strat)} device_order: {device_order} {cost_txt}{attr_color(str(costs))}" - else: - line = ( - f" {plc_txt}{attr_color(strat)} {cost_txt}{attr_color(str(costs))}" - ) + device_order_str = f" device_order={device_order} " + line = f" {plc_txt}{attr_color(strat)}{device_order_str} {cost_txt}{attr_color(str(costs))}" if node.op == "placeholder": line = f" # {node.name}: {line}" code.insert(l_id, line) @@ -654,6 +676,9 @@ def get_log(self, colored=False): while not code[l_id].lstrip().startswith(repr(node)): l_id += 1 code[l_id] += line + if preline: + code.insert(l_id, preline) + l_id += 1 l_id += 1 code = "\n".join(code) total_cost = sum(self.ds[x]["cost"] for x in self.res)