From c1eaf05ceffaa3502d06758237413b41b70fd72b Mon Sep 17 00:00:00 2001
From: Sanket Jayant Purandare
Date: Thu, 16 Oct 2025 12:36:32 -0700
Subject: [PATCH 1/2] Graph Multiplex Pass Example

---
 autoparallel/_passes/graph_multiplex.py | 100 ++++++++++++++++++++++++
 autoparallel/api.py                     |   5 +-
 examples/example_llama3.py              |  11 ++-
 3 files changed, 114 insertions(+), 2 deletions(-)
 create mode 100644 autoparallel/_passes/graph_multiplex.py

diff --git a/autoparallel/_passes/graph_multiplex.py b/autoparallel/_passes/graph_multiplex.py
new file mode 100644
index 00000000..98867d32
--- /dev/null
+++ b/autoparallel/_passes/graph_multiplex.py
@@ -0,0 +1,100 @@
+import copy
+
+import torch
+import torch.fx as fx
+
+
+def multiplex_fw_bw_graph(
+    fw_gm: fx.GraphModule, bw_gm: fx.GraphModule
+) -> fx.GraphModule:
+    """
+    Multiplexes forward and backward graphs into a single unified graph module.
+
+    This function combines a forward graph and a backward graph into one multiplexed
+    graph by merging their nodes and outputs. The resulting graph has:
+    - All placeholders from both forward and backward graphs (backward followed by forward)
+    - All computation nodes from both graphs (backward followed by forward)
+    - Combined outputs (backward outputs followed by forward outputs)
+
+    Args:
+        fw_gm: The forward graph module containing the forward computation
+        bw_gm: The backward graph module containing the backward computation
+
+    Returns:
+        A multiplexed fx.GraphModule containing both forward and backward computations
+        with backward outputs appearing before forward outputs
+
+    Note:
+        The function preserves node metadata during the merging process.
+    """
+    # Mapping to track correspondence between backward graph nodes and new nodes
+    old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {}
+
+    # Start with a deep copy of the forward graph as the base
+    multiplexed_gm = copy.deepcopy(fw_gm)
+
+    # Collect all placeholder nodes from the backward graph
+    bw_placeholders = []
+    for n in bw_gm.graph.nodes:
+        if n.op == "placeholder":
+            bw_placeholders.append(n)
+
+    # Insert backward placeholders at the beginning of the multiplexed graph.
+    # Iterating in reverse while inserting at the front preserves their original order.
+    with multiplexed_gm.graph.inserting_before():
+        for n in reversed(bw_placeholders):
+            new_placeholder = multiplexed_gm.graph.placeholder(n.name)
+            new_placeholder.meta = n.meta
+            new_placeholder.target = new_placeholder.name
+            old_node_to_new_node[n] = new_placeholder
+
+    # Find the last placeholder and the output node in the multiplexed graph
+    insert_point = None
+    multiplexed_graph_op_node = None
+    for n in multiplexed_gm.graph.nodes:
+        if n.op == "placeholder":
+            insert_point = n
+        if n.op == "output":
+            multiplexed_graph_op_node = n
+
+    # Copy all computation nodes from backward graph into multiplexed graph
+    bw_graph_op_node = None
+    for n in bw_gm.graph.nodes:
+        if n.op == "placeholder":
+            continue
+        if n.op == "output":
+            bw_graph_op_node = n
+            continue
+        with multiplexed_gm.graph.inserting_after(insert_point):
+            # Copy node and remap its arguments using the node mapping
+            new_node = multiplexed_gm.graph.node_copy(
+                n, lambda x: old_node_to_new_node[x]
+            )
+            new_node.meta = n.meta
+            old_node_to_new_node[n] = new_node
+            insert_point = new_node
+
+    assert bw_graph_op_node is not None
+    assert multiplexed_graph_op_node is not None
+
+    # Collect output arguments from backward graph, remapping to new nodes
+    bw_op_node_args = [
+        old_node_to_new_node[n] if n is not None else None
+        for n in bw_graph_op_node.args[0]
+    ]
+
+    # Collect output arguments from forward graph
+    fw_op_node_args = list(multiplexed_graph_op_node.args[0])
+
+    # Remove the old output node and create new combined output
+    insert_point = multiplexed_graph_op_node.prev
+    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)
+
+    # Create combined output with backward outputs first, then forward outputs
+    with multiplexed_gm.graph.inserting_after(insert_point):
+        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)
+
+    multiplexed_gm.graph.eliminate_dead_code()
+    multiplexed_gm.graph.lint()
+    multiplexed_gm.recompile()
+    return multiplexed_gm
diff --git a/autoparallel/api.py b/autoparallel/api.py
index abbbc6df..dc7ec94c 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -467,11 +467,14 @@ def apply_placement(self, sharding_placement=None):
             torch.ops._c10d_functional.wait_tensor.default
         )
 
-        self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
+        parallel_model_fn, fw_module, bw_module = aot_compile_joint_with_descriptors(
             self.joint_with_descriptors,
             fw_compiler=self.compiler_fn,
             bw_compiler=self.compiler_fn,
         )
+        self.parallel_model_fn = parallel_model_fn
+        self.fw_module = fw_module
+        self.bw_module = bw_module
 
         # TODO: this probably belongs in the AOTAutograd API
         # TODO: pytree handling
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index bc41e96c..60317579 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -11,6 +11,7 @@
 from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
+from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
 from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
 from autoparallel.api import AutoParallel
 from autoparallel.auto_bucketing import (
@@ -57,7 +58,7 @@ def model_fn():
     if model_type == "8b":
         model_args = TransformerModelArgs(
             dim=4096,
-            n_layers=32,
+            n_layers=1,
             n_heads=32,
             n_kv_heads=8,
             ffn_dim_multiplier=1.3,
@@ -252,6 +253,14 @@ def _pass(graph):
     sharding_placement = autop.optimize_placement(verbose=True)
     print(f"Took {time.time() - t:.2f} s")
     parallel_mod = autop.apply_placement(sharding_placement)
+    multiplex_graph = True
+    if multiplex_graph:
+        f_gm = autop.fw_module
+        b_gm = autop.bw_module
+        multiplexed_gm = multiplex_fw_bw_graph(f_gm, b_gm)
+        print(f_gm.graph)
+        print(b_gm.graph)
+        print(multiplexed_gm.graph)
 
 # run weight init on our sharded DTensor params
 parallel_mod.to_empty(device="cuda")

From 962a2313090d440a854954d20bfadbf10583b58b Mon Sep 17 00:00:00 2001
From: Sanket Jayant Purandare
Date: Thu, 16 Oct 2025 13:38:10 -0700
Subject: [PATCH 2/2] Adding split FSDP Collective Pass

---
 autoparallel/_passes/graph_multiplex.py       |  5 ++
 .../_passes/split_fsdp_collectives.py         | 61 +++++++++++++++++++
 examples/example_llama3.py                    | 16 ++++-
 3 files changed, 81 insertions(+), 1 deletion(-)
 create mode 100644 autoparallel/_passes/split_fsdp_collectives.py

diff --git a/autoparallel/_passes/graph_multiplex.py b/autoparallel/_passes/graph_multiplex.py
index 98867d32..e1eadc8d 100644
--- a/autoparallel/_passes/graph_multiplex.py
+++ b/autoparallel/_passes/graph_multiplex.py
@@ -1,3 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
 import copy
 
 import torch
 import torch.fx as fx
diff --git a/autoparallel/_passes/split_fsdp_collectives.py b/autoparallel/_passes/split_fsdp_collectives.py
new file mode 100644
index 00000000..387b9e4a
--- /dev/null
+++ b/autoparallel/_passes/split_fsdp_collectives.py
@@ -0,0 +1,61 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import dataclasses
+
+import torch
+import torch.utils._pytree as pytree
+from torch._functorch._aot_autograd.descriptors import AOTOutput
+from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
+
+
+@dataclasses.dataclass(frozen=True)
+class PrefetchOutput(AOTOutput):
+    pass
+
+
+def split_fsdp_prefetch(
+    gm: torch.fx.GraphModule,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = gm.graph
+    g_ins = g.find_nodes(op="placeholder")
+    prefetch_g_outs_map = {}
+
+    for g_in in g_ins:
+        n = g_in
+        while True:
+            if len(n.users) != 1:
+                break
+            user = next(iter(n.users))
+            if len(user.all_input_nodes) > 1:
+                break
+            n = user
+        prefetch_g_outs_map[g_in] = n
+
+    prefetch_g_outs = list(prefetch_g_outs_map.values())
+    prefetch_g_outs_descs: list[AOTOutput] = [
+        PrefetchOutput() for _ in range(len(prefetch_g_outs))
+    ]
+
+    prefetch_g = _extract_graph_with_inputs_outputs(
+        g,
+        g_ins,
+        prefetch_g_outs,
+        prefetch_g_outs_descs,
+    )
+
+    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
+    g_outs_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    )
+    main_g = _extract_graph_with_inputs_outputs(
+        g,
+        prefetch_g_outs,
+        g_outs,
+        g_outs_descs,
+    )
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
+    return prefetch_gm, main_gm
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index 60317579..99f63aa5 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -12,6 +12,7 @@
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
+from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch
 from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
 from autoparallel.api import AutoParallel
 from autoparallel.auto_bucketing import (
@@ -257,9 +258,22 @@ def _pass(graph):
     if multiplex_graph:
         f_gm = autop.fw_module
         b_gm = autop.bw_module
-        multiplexed_gm = multiplex_fw_bw_graph(f_gm, b_gm)
+        print("Original Fwd Graph:")
         print(f_gm.graph)
+        print("Original Bwd Graph:")
         print(b_gm.graph)
+        prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm)
+        print("Main Fwd Graph:")
+        print(main_f_gm.graph)
+        print("Prefetch Fwd Graph:")
+        print(prefetch_f_gm.graph)
+        prefetch_b_gm, main_b_gm = split_fsdp_prefetch(b_gm)
+        print("Main Bwd Graph:")
+        print(main_b_gm.graph)
+        print("Prefetch Bwd Graph:")
+        print(prefetch_b_gm.graph)
+        multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm)
+        print("Multiplexed Graph:")
         print(multiplexed_gm.graph)
 
 # run weight init on our sharded DTensor params
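
How the two passes compose (a minimal sketch, not part of the patch): split_fsdp_prefetch walks, for each graph input, the chain of nodes that have a single user and a single input (e.g. the FSDP all-gather and wait_tensor on a sharded parameter) and cuts the graph there, returning a prefetch graph (original inputs -> prefetched values) and a main graph that consumes the prefetched values; the example then multiplexes the two main graphs. The helper name build_joint_graphs below is hypothetical, and `autop` is assumed to be an AutoParallel instance after apply_placement(), so autop.fw_module / autop.bw_module are populated as in the api.py change above.

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch


def build_joint_graphs(autop):
    # Peel the FSDP prefetch chains (all-gather + wait) off each graph ...
    prefetch_f_gm, main_f_gm = split_fsdp_prefetch(autop.fw_module)
    prefetch_b_gm, main_b_gm = split_fsdp_prefetch(autop.bw_module)
    # ... then merge the remaining main graphs into one module whose outputs
    # are the backward outputs followed by the forward outputs.
    multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm)
    return (prefetch_f_gm, prefetch_b_gm), multiplexed_gm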
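
A tiny sanity check of the multiplex pass on hand-traced graphs (illustrative only; the Fwd/Bwd modules are stand-ins, not from the patch). They return tuples so that each graph's output node has an iterable args[0], which the pass expects.

import torch
import torch.fx as fx

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph


class Fwd(torch.nn.Module):
    def forward(self, x):
        return (x + 1,)


class Bwd(torch.nn.Module):
    def forward(self, grad_out):
        return (grad_out * 2,)


fw_gm = fx.symbolic_trace(Fwd())
bw_gm = fx.symbolic_trace(Bwd())
joint_gm = multiplex_fw_bw_graph(fw_gm, bw_gm)
# Placeholders are ordered backward-then-forward (grad_out, x), and the output
# lists the backward result before the forward result.
print(joint_gm.graph)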