diff --git a/autoparallel/api.py b/autoparallel/api.py
index ce087738..8df62a68 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -180,6 +180,7 @@ def __init__(
         enable_ac: bool = True,
         # None means 'auto'
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
+        enable_asynctp: bool = False,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -210,6 +211,8 @@ def __init__(
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
 
+        self.enable_asynctp = enable_asynctp
+
         # NB: rest of the construction happens in __enter__
         self.active = False
 
@@ -236,6 +239,7 @@ def __enter__(self):
            self.mesh,
            rescale_grad_comm_cost_for_mp,
            repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False),
+           enable_asynctp=self.enable_asynctp,
        )
 
        # makes sharding of params and gradients the same
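
The flag added here only takes effect in ShardingOptimizer.get_solution (next file), where the deltas returned by add_asynctp_scores are folded into the ILP objective. A toy PuLP sketch of that mechanism, with illustrative variable names and made-up costs, not the real optimizer:

    import pulp

    prob = pulp.LpProblem("toy_sharding", pulp.LpMinimize)
    # Two mutually exclusive placement choices for one mm input.
    va_shard = pulp.LpVariable("shard_then_allgather", cat="Binary")
    va_replicate = pulp.LpVariable("keep_replicated", cat="Binary")
    prob += va_shard + va_replicate == 1

    opt_target = {va_shard: 100.0, va_replicate: 90.0}  # base comm/compute costs
    # asyncTP-style bonus: discount 30% of the comm cost of the fusable choice.
    opt_target[va_shard] += -0.3 * 100.0

    prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])
    prob.solve(pulp.PULP_CBC_CMD(msg=False))
    assert pulp.value(va_shard) == 1  # the discounted placement now wins (70 < 90)
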
diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 691ec9e1..8831b991 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -105,6 +105,8 @@
 from .propagation_rules import _create_all_options
 from .utils import get_local_map_placement_option, get_placement_options
 
+aten = torch.ops.aten
+
 
 def _debug_node(node):
     def my_print(x):
@@ -128,7 +130,12 @@ def _get_next_name(name):
 
 class ShardingOptimizer:
     def __init__(
-        self, gm, mesh, rescale_grad_comm_cost_for_mp=1.0, repeated_subgraphs=False
+        self,
+        gm,
+        mesh,
+        rescale_grad_comm_cost_for_mp=1.0,
+        repeated_subgraphs=False,
+        enable_asynctp=False,
     ):
         self.gm = gm
         self.graph = gm.graph
@@ -139,11 +146,12 @@ def __init__(
         self.strats = self.build_sharding_metadata()
 
         self.cluster_links = {}
         if repeated_subgraphs:
             t = time.time()
             clusters = get_identical_regions(self.gm.graph, self.strats)
             print(f"Found {len(clusters)} clusters in {time.time() - t:.2f}s")
             self.create_cluster_links(clusters)
+        self.enable_asynctp = enable_asynctp
 
         # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data
         # Each key represents a choice of input placement ii and output placement ss
@@ -556,11 +564,11 @@ def print_old(self):
         print(self.get_violated_constraints_log())
 
     def get_log(self, colored=False):
-        from torch.fx.graph import _color_fns, _identity
 
         opt = {}
         nodes = list(self.graph.nodes)
 
+        log_shapes = False  # flip to True to annotate each entry with input/output shapes
         for x in self.res:
             opt.setdefault(nodes[x[0]], []).append(self.ds[x])
 
@@ -600,10 +608,28 @@ def get_log(self, colored=False):
                code.insert(l_id, line)
                l_id += 1
                continue
+            log_extra = ""
+            if log_shapes:
+                if (
+                    isinstance(node, torch.fx.Node)
+                    and "val" in node.meta
+                    and isinstance(node.meta["val"], torch.Tensor)
+                ):
+                    log_extra += "("
+                    for arg in node.args:
+                        if (
+                            isinstance(arg, torch.fx.Node)
+                            and "val" in arg.meta
+                            and isinstance(arg.meta["val"], torch.Tensor)
+                        ):
+                            log_extra += str(list(arg.meta["val"].shape))
+                    log_extra += ") -> "
+                    log_extra += str(list(node.meta["val"].shape))
+                    log_extra += "\n"
            # LOL
            while not code[l_id].lstrip().startswith(repr(node)):
                l_id += 1
-            code[l_id] += line
+            code[l_id] = log_extra + code[l_id] + line
            l_id += 1
        code = "\n".join(code)
        total_cost = sum(self.ds[x]["cost"] for x in self.res)
@@ -641,6 +667,11 @@ def get_solution(self, verbose=False):
         # add their costs
         for x in self.ds.values():
             opt_target[x["va"]] += x["cost"]
+
+        if self.enable_asynctp:
+            for va, cost in self.add_asynctp_scores().items():
+                opt_target[va] += cost
+
         self.prob += pulp.lpSum([va * cost for va, cost in opt_target.items()])
 
         # solver = pulp.HiGHS(msg=verbose)
@@ -885,6 +916,105 @@ def add_sharded_output_constraint(self, output_placements=None):
            "them from the graph to avoid aliasing."
        )
 
+    def add_asynctp_scores(self):
+        # Encourage placements that enable asyncTP fusions by discounting
+        # a fraction (currently 30%) of their comm cost in the objective:
+        # 1. ag + mm: S(d) -> R, d < mm.ndim - 1
+        # 2. mm + rs: P -> S(d)
+        # TODO1: Filter out FSDP ag/rs that will not be asyncTPed
+        # TODO2: With AsyncTP the expected win is overlapping
+        # ((group_size - 1) / group_size) of the communication,
+        # minus the cost of the decomposition.
+        # For this we need the group_size from the redistribution.
+        def _get_transformations(src_spec, tgt_spec):
+            # TODO: Use real transform preparation
+            # For now we just compare placements left to right
+            src_pls = src_spec.placements
+            tgt_pls = tgt_spec.placements
+            transformations = []
+            for src_pl, tgt_pl in zip(src_pls, tgt_pls):
+                if src_pl == tgt_pl:
+                    continue
+                transformations.append((src_pl, tgt_pl))
+            return transformations
+
+        def _produces_asynctp_ag(src_spec, tgt_spec, mm_dim):
+            # Check that the last transition will be S(dim) -> Replicate
+
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return (
+                last_t[1].is_replicate()
+                and last_t[0].is_shard()
+                and last_t[0].dim < mm_dim - 1
+            )
+
+        def _produces_asynctp_rs(src_spec, tgt_spec, mm_dim):
+            # Check that the last transition will be P -> S(dim)
+            transformations = _get_transformations(src_spec, tgt_spec)
+            if len(transformations) == 0:
+                return False
+            last_t = transformations[-1]
+            return last_t[0].is_partial() and last_t[1].is_shard()
+
+        va_cost_delta = defaultdict(int)
+        strats = self.strats
+        for s_i, (node, s) in enumerate(strats.items()):
+            if not (node.op == "call_function" and node.target == aten.mm.default):
+                continue
+            mm_n = node
+            # Incentivize ag+mm
+            # arg0 of mm should be S(dim) -> R to have an all_gather before the mm
+            a_n = node.args[0]
+            mm_sts = s.strategies
+            for mm_st_i, mm_st in enumerate(mm_sts):
+                a_sts = strats[a_n].strategies
+                mm_tgt_spec = mm_st.input_specs[0]
+                for a_st_i, a_st in enumerate(a_sts):
+                    a_src_spec = a_st.output_spec
+                    # TODO: Is adding a constraint on the arg enough, or do we need to
+                    # follow the arg's ancestors and find the first sharding change?
+                    if _produces_asynctp_ag(
+                        a_src_spec, mm_tgt_spec, mm_n.meta["val"].ndim
+                    ):
+                        # TODO: We want to calculate the cost of the specific AG, as it will be pipelined;
+                        # for now we use just the redistribution cost
+                        cost = mm_st.redistribute_cost[0][a_st_i]
+                        if cost == float("inf"):
+                            continue
+                        va = self.ds[(s_i, 0, mm_st_i, a_st_i)]["va"]
+                        va_cost_delta[va] += -0.3 * cost
+                # Incentivize mm+rs
+                src_spec = mm_st.output_spec
+                if len(mm_n.users) == 0:
+                    continue
+                mm_user = next(iter(mm_n.users))
+                mm_user_s_i = self.node_map[mm_user]
+                mm_u_arg_mm_i = -1
+                for i, arg in enumerate(mm_user.args):
+                    if arg == mm_n:
+                        mm_u_arg_mm_i = i
+                assert mm_u_arg_mm_i != -1
+                mm_user_sts = strats[mm_user].strategies
+                for mm_u_st_i, mm_u_st in enumerate(mm_user_sts):
+                    if _produces_asynctp_rs(
+                        src_spec,
+                        mm_u_st.input_specs[mm_u_arg_mm_i],
+                        mm_n.meta["val"].ndim,
+                    ):
+                        # TODO: We want to calculate the cost of the specific RS, as it will be pipelined;
+                        # for now we use just the redistribution cost
+                        cost = mm_u_st.redistribute_cost[mm_u_arg_mm_i][mm_st_i]
+                        if cost == float("inf"):
+                            continue
+                        key = (mm_user_s_i, mm_u_arg_mm_i, mm_u_st_i, mm_st_i)
+                        va = self.ds[key]["va"]
+                        va_cost_delta[va] += -0.3 * cost
+
+        return va_cost_delta
+
     def validate(self):
         for node in self.graph.nodes:
             if node.op != "call_function":
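
For illustration, a standalone sketch (not part of the diff) of the two transitions that _produces_asynctp_ag and _produces_asynctp_rs look for, using the public DTensor placement types; last_transition is a hypothetical stand-in for _get_transformations:

    # Import path for recent torch; older versions expose these under torch.distributed._tensor.
    from torch.distributed.tensor import Partial, Replicate, Shard

    def last_transition(src_placements, tgt_placements):
        # Mirror _get_transformations: scan mesh dims left to right and keep
        # only the (src, tgt) pairs that actually change.
        changed = [(s, t) for s, t in zip(src_placements, tgt_placements) if s != t]
        return changed[-1] if changed else None

    # ag + mm: an mm input redistributed S(d) -> R (with d before the mm's last dim)
    # lowers to an all_gather that asyncTP can pipeline into the matmul.
    assert last_transition((Shard(0),), (Replicate(),)) == (Shard(0), Replicate())

    # mm + rs: an mm output redistributed P -> S(d) lowers to a reduce_scatter
    # that asyncTP can fuse with the producing matmul.
    prev, new = last_transition((Partial(),), (Shard(1),))
    assert prev.is_partial() and new.is_shard()

Only the last placement change across mesh dims is inspected; real redistribution planning may reorder transforms, which is why _get_transformations is marked as a TODO in the diff.
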
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index bc41e96c..8d7468cb 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -211,7 +211,13 @@ def add_tp_constraints(autop):
 
 # parallelize the model
 with AutoParallel(
-    model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True
+    model,
+    input_fn,
+    mesh,
+    mp_policy,
+    compile=True,
+    repeated_subgraphs=True,
+    enable_asynctp=enable_asynctp,
 ) as autop:
     autop.add_parameter_memory_constraint(low=None, high=None)
 
@@ -229,22 +235,55 @@ def add_tp_constraints(autop):
     if enable_manual_constraint and not use_1d_mesh:
         add_tp_constraints(autop)
 
-    if enable_asynctp:
-        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+    enable_overlap_scheduling = True
+    enable_overlap_scheduling_bucketing = True
+    if enable_overlap_scheduling_bucketing:
+        assert (
+            enable_overlap_scheduling
+        ), "bucketing cannot be used without overlap scheduling"
+    enable_asynctp = True
+    import autoparallel.asynctp
+    autoparallel.asynctp._micro_pipeline_tp_ag_transpose_mm_enabled = True
+    autoparallel.asynctp._micro_pipeline_tp_ag_mm_last_dim_enabled = True
+    if (
+        enable_overlap_scheduling
+        or enable_overlap_scheduling_bucketing
+        or enable_asynctp
+    ):
+        torch._inductor.config.reorder_for_peak_memory = False
+        torch._inductor.config.reorder_for_compute_comm_overlap = False
+        torch._inductor.config.allow_buffer_reuse = False
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+            enable_overlap_scheduling_bucketing
+        )
+
+    if enable_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-        enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
-        enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
-        torch._inductor.config._micro_pipeline_tp = False
-        from autoparallel.asynctp import micro_pipeline_tp_pass
+        enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+        enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = False
+        # Disable inductor's AsyncTP passes in favor of the Autoparallel fork of these passes.
+        # TODO: Switch back to the inductor AsyncTP passes once all additions have landed.
+        from autoparallel.asynctp import micro_pipeline_tp_pass
 
         existing_post_grad_custom_post_pass = (
             torch._inductor.config.post_grad_custom_post_pass
         )
+        from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
 
         def _pass(graph):
             if existing_post_grad_custom_post_pass is not None:
                 existing_post_grad_custom_post_pass(graph)
-            micro_pipeline_tp_pass(graph)
+
+            collective_info = None
+            if enable_overlap_scheduling:
+                overlap_scheduler = OverlapScheduler(graph.owning_module)
+                overlap_scheduler.run()
+                collective_info = overlap_scheduler.collective_info
+
+            if enable_asynctp:
+                micro_pipeline_tp_pass(graph, collective_info)
 
         torch._inductor.config.post_grad_custom_post_pass = _pass
 
diff --git a/mast/sweep.py b/mast/sweep.py
index e867bd43..130cace2 100644
--- a/mast/sweep.py
+++ b/mast/sweep.py
@@ -123,6 +123,12 @@ def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]:
         "--model.name=llama3",
         "--compile.enable",
     ],
+    "llama3_FSDP_tp_async_tp_compile": llama3_2d_common_opts
+    + [
+        "--model.name=llama3",
+        "--compile.enable",
+        "--parallelism.enable_async_tensor_parallel",
+    ],
     "llama3_autop_2d_compile": llama3_2d_common_opts
     + [
         "--model.name=llama3_auto_parallel",