diff --git a/autoparallel/asynctp.py b/autoparallel/asynctp.py
index 9b708ccd..d80f6b59 100644
--- a/autoparallel/asynctp.py
+++ b/autoparallel/asynctp.py
@@ -13,6 +13,7 @@
 import torch
 from torch._inductor import inductor_prims
+from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
 from torch._inductor.pattern_matcher import (
     MULTIPLE,
     CallFunction,
@@ -23,6 +24,7 @@
     PatternExpr,
     PatternMatcherPass,
 )
+from torch._logging import trace_structured
 from torch.utils._ordered_set import OrderedSet
 
 import autoparallel.asynctp_ops  # noqa: F401
@@ -34,7 +36,7 @@
 _micro_pipeline_tp_ag_transpose_mm_enabled = True
 
 # Check performance if overhead of decomposition outweights pipeline wins
-_micro_pipeline_tp_ag_mm_last_dim_enabled = False
+_micro_pipeline_tp_ag_mm_last_dim_enabled = True
 
 _micro_pipeline_tp_mm_rs_last_dim_enabled = True
 
@@ -720,7 +722,7 @@ def _insert_fused_all_gather_transpose_matmul(
         raise AssertionError(f"Unexpected matmul match type: {mm_type}")
 
 
-def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
+def fuse_all_gather_matmul(all_gather: _AllGatherMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -755,6 +757,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         all_gather.group_name,
     )
 
+    log_strs.append(f"fuse_agmm {all_gather}")
     if not is_symm_mem_enabled_for_group(group_name):
         return
 
@@ -774,6 +777,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
         for matmul in matmuls
         if all_gather.res_node not in matmul.arg_ancestor_nodes
     ]
+    log_strs.append(f"fuse_agmm matmuls:{matmuls}")
 
     if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1:
         return
@@ -870,6 +874,7 @@ def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None:
     for node in nodes_to_raise:
         if order[node] > order[fused_node]:
             fused_node.prepend(node)
+    log_strs.append("fuse_agmm DONE")
 
 
 def _scatter_dim_after_reshape(
@@ -990,7 +995,7 @@ def _insert_fused_matmul_reduce_scatter(
         raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")
 
 
-def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
+def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch, log_strs) -> None:
     """
     Fused the pattern
 
@@ -1004,6 +1009,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
 
     Returns boolean indicating if fusion was successful or not.
     """
+    log_strs.append(f"fuse_mmrs {reduce_scatter}")
     if (
         not torch.distributed.is_available()
         or not torch.distributed.is_nccl_available()
@@ -1032,6 +1038,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
     )
 
     if not is_symm_mem_enabled_for_group(group_name):
+        log_strs.append("fuse_mmrs not symm mem group")
         return
 
     if (
@@ -1048,16 +1055,19 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
         log.warning(
             "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
         )
+        log_strs.append("fuse_mmrs input.node.users != 1")
         return
 
     matmul = _find_producer_matmul(input_node)
     if matmul is None:
+        log_strs.append("fuse_mmrs no matmul")
         log.warning(
             "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
         )
         return
 
     if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
+        log_strs.append("fuse_mmrs wait in matmul.arg_ancestors")
         log.warning(
             "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
         )
@@ -1123,6 +1133,7 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
         if order[node] > order[fused_node]:
             fused_node.prepend(node)
 
+    log_strs.append("fuse_mmrs DONE")
     log.debug("successfully fused matmul reduce scatter")
 
 
@@ -1173,6 +1184,7 @@ def is_collective(node) -> bool:
     return collective_to_overlappable_nodes
 
 
+# TODO: Convert return type to set
def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]:
     """
     Find all unexposed collectives in the graph.
@@ -1209,18 +1221,47 @@ def _is_compute_intensive(node: torch.fx.Node) -> bool:
     return unexposed_collectives
 
 
-def micro_pipeline_tp_pass(graph: torch.fx.Graph):
+def micro_pipeline_tp_pass(
+    graph: torch.fx.Graph,
+    collective_info: Optional[dict[torch.fx.Node, CollectiveInfo]] = None,
+):
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_pre_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: graph.owning_module.print_readable(False),
+    )
     all_gathers = find_all_gather_patterns(graph)
     reduce_scatters = find_reduce_scatter_patterns(graph)
 
     # When a collective can be hidden through either simple overlapping or
     # micro-pipeline TP, we prefer simple overlapping to avoid the overhead
     # associated with decomposition.
-    unexposed_collectives = _get_unexposed_collectives(graph)
+    # Known problem: collective info contains only pre-bucketing collectives
+    if collective_info is None:
+        unexposed_collectives = _get_unexposed_collectives(graph)
+    else:
+        unexposed_collectives = [
+            n
+            for n, col_info in collective_info.items()
+            if col_info.hiding_node is not None
+        ]
+
+    log_strs = []
+    log_strs.append(f"\n all_gathers:{all_gathers}")
+    log_strs.append(f"\n reduce_scatters:{reduce_scatters}")
+
     all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
     reduce_scatters = [
         x for x in reduce_scatters if x.reduce_scatter_node not in unexposed_collectives
     ]
+    for n, coll_info in (collective_info or {}).items():  # guard: collective_info may be None
+        log_strs.append(f"coll_info {n}: {coll_info}")
+    log_strs.append(f"\n unexposed_collectives:{unexposed_collectives}")
+    log_strs.append(f"\n all_gathers_exposed:{all_gathers}")
+    log_strs.append(f"\n reduce_scatters_exposed:{reduce_scatters}")
 
     if not all_gathers and not reduce_scatters:
         log.warning(
@@ -1228,7 +1269,23 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
         )
 
     for reduce_scatter in reduce_scatters:
-        fuse_matmul_reduce_scatter(reduce_scatter)
+        fuse_matmul_reduce_scatter(reduce_scatter, log_strs)
 
     for all_gather in all_gathers:
-        fuse_all_gather_matmul(all_gather)
+        fuse_all_gather_matmul(all_gather, log_strs)
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_log",
+            "encoding": "string",
+        },
+        payload_fn=lambda: "\n".join(log_strs),
+    )
+    trace_structured(
+        "artifact",
+        metadata_fn=lambda: {
+            "name": "asynctp_post_graph",
+            "encoding": "string",
+        },
+        payload_fn=lambda: graph.owning_module.print_readable(False),
+    )
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index 27cb8633..dbbbb5c7 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -192,22 +192,53 @@ def add_tp_constraints(autop):
     if enable_manual_constraint and not use_1d_mesh:
         add_tp_constraints(autop)
 
-    if enable_asynctp:
-        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+    enable_overlap_scheduling = True
+    enable_overlap_scheduling_bucketing = True
+    if enable_overlap_scheduling_bucketing:
+        assert (
+            enable_overlap_scheduling
+        ), "bucketing cannot be used without overlap scheduling"
+    enable_asynctp = True
+
+    if (
+        enable_overlap_scheduling
+        or enable_overlap_scheduling_bucketing
+        or enable_asynctp
+    ):
+        torch._inductor.config.reorder_for_peak_memory = False
+        torch._inductor.config.reorder_for_compute_comm_overlap = False
+        torch._inductor.config.allow_buffer_reuse = False
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+            enable_overlap_scheduling_bucketing
+        )
+
+    if enable_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-        enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
-        enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
-        torch._inductor.config._micro_pipeline_tp = False
-        from autoparallel.asynctp import micro_pipeline_tp_pass
+        enable_symm_mem_for_group(mesh["tp"].get_group().group_name)
+        enable_symm_mem_for_group(mesh["dp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = False
+        # Disable the inductor AsyncTP passes in favor of the autoparallel fork of them.
+        # TODO: Switch back to the inductor AsyncTP passes once all additions have landed.
+        from autoparallel.asynctp import micro_pipeline_tp_pass
 
         existing_post_grad_custom_post_pass = (
             torch._inductor.config.post_grad_custom_post_pass
         )
+        from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
 
         def _pass(graph):
             if existing_post_grad_custom_post_pass is not None:
                 existing_post_grad_custom_post_pass(graph)
-            micro_pipeline_tp_pass(graph)
+
+            collective_info = None
+            if enable_overlap_scheduling:
+                overlap_scheduler = OverlapScheduler(graph.owning_module)
+                overlap_scheduler.run()
+                collective_info = overlap_scheduler.collective_info
+
+            if enable_asynctp:
+                micro_pipeline_tp_pass(graph, collective_info)
 
         torch._inductor.config.post_grad_custom_post_pass = _pass
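For reviewers, a condensed sketch of how the two changes above are meant to be wired together, distilled from the examples/example_llama3.py hunk (the helper name register_asynctp_post_grad_pass is illustrative only; mesh setup, symmetric-memory enablement, and the reordering/bucketing config knobs are omitted here):

    import torch
    from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
    from autoparallel.asynctp import micro_pipeline_tp_pass

    def register_asynctp_post_grad_pass(enable_overlap_scheduling: bool = True) -> None:
        # Use the autoparallel fork of the AsyncTP pass instead of inductor's built-in one.
        torch._inductor.config._micro_pipeline_tp = False

        def _pass(graph: torch.fx.Graph) -> None:
            collective_info = None
            if enable_overlap_scheduling:
                # Run the overlap scheduler first so micro_pipeline_tp_pass can skip
                # collectives that are already hidden behind compute.
                scheduler = OverlapScheduler(graph.owning_module)
                scheduler.run()
                collective_info = scheduler.collective_info
            micro_pipeline_tp_pass(graph, collective_info)

        torch._inductor.config.post_grad_custom_post_pass = _pass

With TORCH_TRACE enabled, the asynctp_pre_graph, asynctp_log, and asynctp_post_graph artifacts emitted via trace_structured should appear in the structured trace and can be inspected with tlparse.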