From c1eaf05ceffaa3502d06758237413b41b70fd72b Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Thu, 16 Oct 2025 12:36:32 -0700 Subject: [PATCH 1/3] Graph Multiplex Pass Example --- autoparallel/_passes/graph_multiplex.py | 100 ++++++++++++++++++++++++ autoparallel/api.py | 5 +- examples/example_llama3.py | 11 ++- 3 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 autoparallel/_passes/graph_multiplex.py diff --git a/autoparallel/_passes/graph_multiplex.py b/autoparallel/_passes/graph_multiplex.py new file mode 100644 index 00000000..98867d32 --- /dev/null +++ b/autoparallel/_passes/graph_multiplex.py @@ -0,0 +1,100 @@ +import copy + +import torch +import torch.fx as fx + + +def multiplex_fw_bw_graph( + fw_gm: fx.GraphModule, bw_gm: fx.GraphModule +) -> fx.GraphModule: + """ + Multiplexes forward and backward graphs into a single unified graph module. + + This function combines a forward graph and a backward graph into one multiplexed + graph by merging their nodes and outputs. The resulting graph has: + - All placeholders from both forward and backward graphs (backward followed by forward) + - All computation nodes from both graphs (backward followed by forward) + - Combined outputs (backward outputs followed by forward outputs) + + Args: + fw_gm: The forward graph module containing the forward computation + bw_gm: The backward graph module containing the backward computation + + Returns: + A multiplexed fx.GraphModule containing both forward and backward computations + with backward outputs appearing before forward outputs + + Note: + The function preserves node metadata during the merging process. + """ + # Mapping to track correspondence between backward graph nodes and new nodes + old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {} + + # Start with a deep copy of the forward graph as the base + multiplexed_gm = copy.deepcopy(fw_gm) + + # Collect all placeholder nodes from the backward graph + bw_placeholders = [] + for n in bw_gm.graph.nodes: + if n.op == "placeholder": + bw_placeholders.append(n) + + # Insert backward placeholders at the beginning of the multiplexed graph + # Reversed order ensures correct execution sequence + with multiplexed_gm.graph.inserting_before(): + for n in reversed(bw_placeholders): + new_placeholder = multiplexed_gm.graph.placeholder(n.name) + new_placeholder.meta = n.meta + new_placeholder.target = new_placeholder.name + old_node_to_new_node[n] = new_placeholder + + # Find the last placeholder and the output node in the multiplexed graph + insert_point = None + multiplexed_graph_op_node = None + for n in multiplexed_gm.graph.nodes: + if n.op == "placeholder": + insert_point = n + if n.op == "output": + multiplexed_graph_op_node = n + + # Copy all computation nodes from backward graph into multiplexed graph + bw_graph_op_node = None + for n in bw_gm.graph.nodes: + if n.op == "placeholder": + continue + if n.op == "output": + bw_graph_op_node = n + continue + with multiplexed_gm.graph.inserting_after(insert_point): + # Copy node and remap its arguments using the node mapping + new_node = multiplexed_gm.graph.node_copy( + n, lambda x: old_node_to_new_node[x] + ) + new_node.meta = n.meta + old_node_to_new_node[n] = new_node + insert_point = new_node + + assert bw_graph_op_node is not None + assert multiplexed_graph_op_node is not None + + # Collect output arguments from backward graph, remapping to new nodes + bw_op_node_args = [ + old_node_to_new_node[n] if n is not None else None + for n in bw_graph_op_node.args[0] + ] 
+ + # Collect output arguments from forward graph + fw_op_node_args = list(multiplexed_graph_op_node.args[0]) + + # Remove the old output node and create new combined output + insert_point = multiplexed_graph_op_node.prev + multiplexed_gm.graph.erase_node(multiplexed_graph_op_node) + + # Create combined output with backward outputs first, then forward outputs + with multiplexed_gm.graph.inserting_after(insert_point): + multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args) + + multiplexed_gm.graph.eliminate_dead_code() + multiplexed_gm.graph.lint() + multiplexed_gm.recompile() + return multiplexed_gm diff --git a/autoparallel/api.py b/autoparallel/api.py index abbbc6df..dc7ec94c 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -467,11 +467,14 @@ def apply_placement(self, sharding_placement=None): torch.ops._c10d_functional.wait_tensor.default ) - self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors( + parallel_model_fn, fw_module, bw_module = aot_compile_joint_with_descriptors( self.joint_with_descriptors, fw_compiler=self.compiler_fn, bw_compiler=self.compiler_fn, ) + self.parallel_model_fn = parallel_model_fn + self.fw_module = fw_module + self.bw_module = bw_module # TODO: this probably belongs in the AOTAutograd API # TODO: pytree handling diff --git a/examples/example_llama3.py b/examples/example_llama3.py index bc41e96c..60317579 100644 --- a/examples/example_llama3.py +++ b/examples/example_llama3.py @@ -11,6 +11,7 @@ from torch.distributed.tensor.placement_types import Partial, Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore +from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs from autoparallel.api import AutoParallel from autoparallel.auto_bucketing import ( @@ -57,7 +58,7 @@ def model_fn(): if model_type == "8b": model_args = TransformerModelArgs( dim=4096, - n_layers=32, + n_layers=1, n_heads=32, n_kv_heads=8, ffn_dim_multiplier=1.3, @@ -252,6 +253,14 @@ def _pass(graph): sharding_placement = autop.optimize_placement(verbose=True) print(f"Took {time.time() - t:.2f} s") parallel_mod = autop.apply_placement(sharding_placement) + multiplex_graph = True + if multiplex_graph: + f_gm = autop.fw_module + b_gm = autop.bw_module + multiplexed_gm = multiplex_fw_bw_graph(f_gm, b_gm) + print(f_gm.graph) + print(b_gm.graph) + print(multiplexed_gm.graph) # run weight init on our sharded DTensor params parallel_mod.to_empty(device="cuda") From 962a2313090d440a854954d20bfadbf10583b58b Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Thu, 16 Oct 2025 13:38:10 -0700 Subject: [PATCH 2/3] Adding split FSDP Collective Pass --- autoparallel/_passes/graph_multiplex.py | 5 ++ .../_passes/split_fsdp_collectives.py | 61 +++++++++++++++++++ examples/example_llama3.py | 16 ++++- 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 autoparallel/_passes/split_fsdp_collectives.py diff --git a/autoparallel/_passes/graph_multiplex.py b/autoparallel/_passes/graph_multiplex.py index 98867d32..e1eadc8d 100644 --- a/autoparallel/_passes/graph_multiplex.py +++ b/autoparallel/_passes/graph_multiplex.py @@ -1,3 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ import copy import torch diff --git a/autoparallel/_passes/split_fsdp_collectives.py b/autoparallel/_passes/split_fsdp_collectives.py new file mode 100644 index 00000000..387b9e4a --- /dev/null +++ b/autoparallel/_passes/split_fsdp_collectives.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses + +import torch +import torch.utils._pytree as pytree +from torch._functorch._aot_autograd.descriptors import AOTOutput +from torch._functorch.partitioners import _extract_graph_with_inputs_outputs + + +@dataclasses.dataclass(frozen=True) +class PrefetchOutput(AOTOutput): + pass + + +def split_fsdp_prefetch( + gm: torch.fx.GraphModule, +) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + g = gm.graph + g_ins = g.find_nodes(op="placeholder") + prefetch_g_outs_map = {} + + for g_in in g_ins: + n = g_in + while True: + if len(n.users) != 1: + break + user = next(iter(n.users)) + if len(user.all_input_nodes) > 1: + break + n = user + prefetch_g_outs_map[g_in] = n + + prefetch_g_outs = list(prefetch_g_outs_map.values()) + prefetch_g_outs_descs: list[AOTOutput] = [ + PrefetchOutput() for _ in range(len(prefetch_g_outs)) + ] + + prefetch_g = _extract_graph_with_inputs_outputs( + g, + g_ins, + prefetch_g_outs, + prefetch_g_outs_descs, + ) + + g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) + g_outs_descs = pytree.arg_tree_leaves( + next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) + ) + main_g = _extract_graph_with_inputs_outputs( + g, + prefetch_g_outs, + g_outs, + g_outs_descs, + ) + main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g) + prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g) + return prefetch_gm, main_gm diff --git a/examples/example_llama3.py b/examples/example_llama3.py index 60317579..99f63aa5 100644 --- a/examples/example_llama3.py +++ b/examples/example_llama3.py @@ -12,6 +12,7 @@ from torch.testing._internal.distributed.fake_pg import FakeStore from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph +from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs from autoparallel.api import AutoParallel from autoparallel.auto_bucketing import ( @@ -257,9 +258,22 @@ def _pass(graph): if multiplex_graph: f_gm = autop.fw_module b_gm = autop.bw_module - multiplexed_gm = multiplex_fw_bw_graph(f_gm, b_gm) + print("Original Fwd Graph:") print(f_gm.graph) + print("Original Bwd Graph:") print(b_gm.graph) + prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm) + print("Main Fwd Graph:") + print(main_f_gm.graph) + print("Prefetch Fwd Graph:") + print(prefetch_f_gm.graph) + prefetch_b_gm, main_b_gm = split_fsdp_prefetch(b_gm) + print("Main Bwd Graph:") + print(main_b_gm.graph) + print("Prefetch Bwd Graph:") + print(prefetch_b_gm.graph) + multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm) + print("Multiplexed Graph:") print(multiplexed_gm.graph) # run weight init on our sharded DTensor params From 0839c3551a3a591fe7bcff09dbe70f5578c6115d Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 28 Oct 2025 19:29:42 -0700 Subject: [PATCH 3/3] Add split dI/dW graph pass example --- autoparallel/_passes/split_di_dw_graph.py | 50 ++++ examples/example_llama3_di_dw.py | 320 ++++++++++++++++++++++ 2 
files changed, 370 insertions(+)
 create mode 100644 autoparallel/_passes/split_di_dw_graph.py
 create mode 100644 examples/example_llama3_di_dw.py

diff --git a/autoparallel/_passes/split_di_dw_graph.py b/autoparallel/_passes/split_di_dw_graph.py
new file mode 100644
index 00000000..a7dc0433
--- /dev/null
+++ b/autoparallel/_passes/split_di_dw_graph.py
@@ -0,0 +1,50 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+import torch
+import torch.fx as fx
+from functorch.compile import default_partition
+
+# We are running the default partitioner on the bw graph, which requires that the AC tags be removed.
+# At this stage we have already finished running AC anyway, since we have a bw graph.
+def remove_recompute_tags(bw_gm):
+    for n in bw_gm.graph.nodes:
+        if 'recompute' in n.meta:
+            del n.meta['recompute']
+
+# We are using the default partitioner to split our backward into dI and dW subgraphs.
+# We want to generate the dI subgraph *first*, because:
+# - in pipelining we generally want to schedule dI compute before dW
+# - the dI compute will potentially compute more activations that we need to plumb into dW compute
+# Today, the default partitioner requires that you split on the first K outputs of your combined graph.
+# So here, we reorder the outputs of the backward so grad_inputs are first.
+def reorder_output_grads(bw_gm, num_weight_gradients):
+    outputs = bw_gm.graph.find_nodes(op='output')
+    assert len(outputs) == 1
+    output = outputs[0]
+    assert isinstance(output.args[0], tuple)
+    grad_weights, grad_inputs = output.args[0][:num_weight_gradients], output.args[0][num_weight_gradients:]
+    new_out_tuple = grad_inputs + grad_weights
+    with bw_gm.graph.inserting_after(output):
+        # TODO: also set the new node's meta properly
+        new_out = bw_gm.graph.output(new_out_tuple)
+    output.replace_all_uses_with(new_out)
+    bw_gm.graph.erase_node(output)
+    return len(grad_inputs)
+
+# TODO: in theory we can infer num_weight_gradients from the graph metadata directly
+def split_di_dw_graph(bw_gm: fx.GraphModule, *, num_weight_gradients) -> tuple[fx.GraphModule, fx.GraphModule]:
+    # we could consider doing this in a non-mutating way
+    bw_gm = copy.deepcopy(bw_gm)
+    remove_recompute_tags(bw_gm)
+    num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients)
+    bw_gm.recompile()
+
+    args = [x.meta['val'] for x in bw_gm.graph.find_nodes(op="placeholder")]
+
+    bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients)
+    return bw_inputs, bw_weights
diff --git a/examples/example_llama3_di_dw.py b/examples/example_llama3_di_dw.py
new file mode 100644
index 00000000..c4c6d1bb
--- /dev/null
+++ b/examples/example_llama3_di_dw.py
@@ -0,0 +1,320 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+from functools import partial
+
+import torch
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.distributed.tensor.placement_types import Partial, Replicate, Shard
+from torch.testing._internal.distributed.fake_pg import FakeStore
+
+from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
+from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
+from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch
+from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
+from autoparallel.api import AutoParallel
+from autoparallel.auto_bucketing import (
+    aten_autobucketing_config,
+    aten_autobucketing_reordering_pass,
+    simple_fsdp_autobucketing_reordering_pass,
+    simplefsdp_autobucketing_config,
+)
+
+world_size = 64
+
+fake_store = FakeStore()
+torch.distributed.init_process_group(
+    "fake", store=fake_store, rank=0, world_size=world_size
+)
+
+use_1d_mesh = False
+
+if use_1d_mesh:
+    mesh = torch.distributed.device_mesh.init_device_mesh(
+        "cuda", (world_size,), mesh_dim_names=("dp",)
+    )
+else:
+    mesh = torch.distributed.device_mesh.init_device_mesh(
+        "cuda",
+        (world_size // 8, 8),
+        mesh_dim_names=(
+            "dp",
+            "tp",
+        ),
+    )
+
+batch_size = 2 * mesh.shape[0]
+seqlen = 2048 * 4
+vocab_size = 128256
+use_vocab_parallel = not use_1d_mesh
+device = torch.device("cuda")
+
+model_type = "8b"
+enable_asynctp = False
+
+
+def model_fn():
+    if model_type == "8b":
+        model_args = TransformerModelArgs(
+            dim=4096,
+            n_layers=1,
+            n_heads=32,
+            n_kv_heads=8,
+            ffn_dim_multiplier=1.3,
+            multiple_of=1024,
+            rope_theta=500000,
+            vocab_size=vocab_size,
+            max_seq_len=seqlen,
+        )
+    elif model_type == "70b":
+        model_args = TransformerModelArgs(
+            dim=8192,
+            n_layers=80,
+            n_heads=64,
+            n_kv_heads=8,
+            ffn_dim_multiplier=1.3,
+            multiple_of=4096,
+            rope_theta=500000,
+            vocab_size=vocab_size,
+            max_seq_len=seqlen,
+        )
+    else:
+        raise ValueError(f"{model_type} not available")
+    m = Transformer(model_args)
+    # I turned off the tok_embeddings layer because:
+    # - I want the input to my joint graph to require grad,
+    #   so there are separate dI and dW gradients to carve out subgraphs for
+    # - The input to tok_embeddings is an integral tensor, which can't require grad.
+    # Another option would be to manually apply autoparallel to a single transformer block layer.
+ m.tok_embeddings = None + return m + + +def input_fn(): + # 8192 for 70B + x = torch.randn(batch_size, seqlen, 4096, device=device, requires_grad=True) + return x + + +autobucketing_level = "aten" + +if autobucketing_level == "aten": + # this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960 + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + aten_autobucketing_reordering_pass = partial( + aten_autobucketing_reordering_pass, + configs=aten_autobucketing_config, + ) + torch._inductor.config.post_grad_custom_post_pass = ( + aten_autobucketing_reordering_pass + ) +elif autobucketing_level == "inductor": + torch._inductor.config.allow_buffer_reuse = False + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = True + simplefsdp_autobucketing_config.calibrate_number = 5 + simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl" + simple_fsdp_autobucketing_reordering_pass = partial( + simple_fsdp_autobucketing_reordering_pass, + configs=simplefsdp_autobucketing_config, + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ + simple_fsdp_autobucketing_reordering_pass + ] +else: + raise ValueError(f"Unknown autobucketing_level {autobucketing_level}") + + +# parallelize the model +with torch.device("meta"): + model = model_fn() + +mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) + + +def group_mm_nodes_with_its_gradients(nodes): + fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta] + bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta] + assert len(fwd_nodes) * 2 == len(bwd_nodes) + res = {} + for fwd_node in fwd_nodes: + o = [] + for bwd_node in bwd_nodes: + if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]: + o.append(bwd_node) + assert len(o) == 2 + res[fwd_node] = o + return res + + +def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False): + # out = x @ w - S(0)R, RS(1) -> S(0)S(1) + # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0) + # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P + + add_node_constraint = autop.sharding_optimizer.add_node_constraint + fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes) + fwd_nodes = list(fwd_bwd_groups.keys()) + dim1 = 0 if feat_dim == 1 else 1 + dim2 = 1 if feat_dim == 1 else 0 + # assume there are 7 mm nodes per transformer block + # skip last mm as it's the final projection layer + assert ( + len(fwd_nodes) - 1 + ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}" + for block in range(0, len(fwd_nodes) - 1, 7): + fwd_nodes_block = fwd_nodes[block : block + 7] + # force the first 3 mm nodes to be S(0)S(1) + the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6] + for n in the_nodes: + add_node_constraint(n, (Shard(0), Shard(feat_dim))) + add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate())) + add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1))) + + if bwd_constraint: + bwd_nodes = fwd_bwd_groups[n] + # first is g_w, second is g_x + add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1))) + add_node_constraint(bwd_nodes[1], (Shard(0), Partial())) + + # add reduction to finish TP, yielding S(0)P + the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7] + for n in the_nodes: + add_node_constraint(n, (Shard(0), Partial())) + add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim))) + 
add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0))) + + if bwd_constraint: + bwd_nodes = fwd_bwd_groups[n] + # first is g_w, second is g_x + add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2))) + add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim))) + + +def add_tp_constraints(autop): + mm_nodes = autop.gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + einsum_nodes = autop.gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.einsum.default + ) + assert (len(mm_nodes) > 0) ^ ( + len(einsum_nodes) > 0 + ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}" + feat_dim = 1 if len(mm_nodes) > 0 else 2 + tgt_nodes = mm_nodes + einsum_nodes + force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True) + + if einsum_nodes: + # add sequence parallelism if we have einsum nodes + autop.sharding_optimizer.add_node_constraint( + list(tgt_nodes[3].users)[0], (Shard(0), Shard(1)) + ) + autop.sharding_optimizer.add_node_constraint( + list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1)) + ) + + +# parallelize the model +with AutoParallel( + model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True +) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) + out_sharding = x_sharding + if use_vocab_parallel: + # add vocab parallel constraint + assert mesh.ndim == 2, "Only 2d mesh supported here" + out_sharding = (Shard(0), Shard(2)) + + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + + enable_manual_constraint = False + if enable_manual_constraint and not use_1d_mesh: + add_tp_constraints(autop) + + if enable_asynctp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + enable_symm_mem_for_group(mesh["dp"].get_group().group_name) + enable_symm_mem_for_group(mesh["tp"].get_group().group_name) + torch._inductor.config._micro_pipeline_tp = False + from autoparallel.asynctp import micro_pipeline_tp_pass + + existing_post_grad_custom_post_pass = ( + torch._inductor.config.post_grad_custom_post_pass + ) + + def _pass(graph): + if existing_post_grad_custom_post_pass is not None: + existing_post_grad_custom_post_pass(graph) + micro_pipeline_tp_pass(graph) + + torch._inductor.config.post_grad_custom_post_pass = _pass + + t = time.time() + sharding_placement = autop.optimize_placement(verbose=True) + print(f"Took {time.time() - t:.2f} s") + parallel_mod = autop.apply_placement(sharding_placement) + multiplex_graph = True + if multiplex_graph: + f_gm = autop.fw_module + b_gm = autop.bw_module + print("Original Fwd Graph:") + print(f_gm.graph) + print("Original Bwd Graph:") + print(b_gm.graph) + prefetch_f_gm, main_f_gm = split_fsdp_prefetch(f_gm) + print("Main Fwd Graph:") + print(main_f_gm.graph) + print("Prefetch Fwd Graph:") + print(prefetch_f_gm.graph) + prefetch_b_gm, main_b_gm = split_fsdp_prefetch(b_gm) + print("Main Bwd Graph:") + print(main_b_gm.graph) + print("Prefetch Bwd Graph:") + print(prefetch_b_gm.graph) + multiplexed_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm) + print("Multiplexed Graph:") + print(multiplexed_gm.graph) + # in the AOTAutograd bw graph, the first K outputs correspond + # to gradients for any params/buffers in the user model + num_weight_gradients = autop.joint_with_descriptors._aot_state.aot_config.num_params_buffers + main_b_gm_di, main_b_gm_dw = split_di_dw_graph(main_b_gm, 
num_weight_gradients=num_weight_gradients) + print("Gradient w.r.t inputs Graph:") + print(main_b_gm_di) + print("Gradient w.r.t weights Graph:") + print(main_b_gm_dw) + # Test just to show that input/output calling conventions are correct + # the pipeline runtime will need to do this + num_total_gradients = len(main_b_gm.graph.find_nodes(op='output')[0].args[0]) + num_input_gradients = num_total_gradients - num_weight_gradients + bw_args = [x.meta['val'] for x in main_b_gm.graph.find_nodes(op="placeholder")] + input_grads_and_activations = main_b_gm_di(*bw_args) + input_grads, activations = input_grads_and_activations[:num_input_gradients], input_grads_and_activations[num_input_gradients:] + weight_grads = main_b_gm_dw(*activations) + print(f'num input grads: {len(input_grads)}') + print(f'num weight grads: {len(weight_grads)}') + +# run weight init on our sharded DTensor params +parallel_mod.to_empty(device="cuda") +parallel_mod.init_weights() + +# now let's run it +x = ( + torch.randint( + 0, + vocab_size, + (batch_size // mesh.shape[0], seqlen), + device=torch.device("cuda"), + ), +) +#out = parallel_mod(*x) +#out.backward(torch.randn_like(out)) +print("All good!")
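Usage sketch (not part of the patches above): how a pipeline-stage runtime might compose the three passes, assuming an `autop` handle produced by AutoParallel.apply_placement as in the examples; the `build_stage_graphs` helper and its return convention are assumptions for illustration only.

# Sketch only: composition of the passes introduced in this series.
from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
from autoparallel._passes.split_fsdp_collectives import split_fsdp_prefetch


def build_stage_graphs(autop, num_weight_gradients):
    # Peel the parameter all-gather chains off the forward and backward graphs
    # so a schedule can prefetch them ahead of the main compute.
    prefetch_f_gm, main_f_gm = split_fsdp_prefetch(autop.fw_module)
    prefetch_b_gm, main_b_gm = split_fsdp_prefetch(autop.bw_module)
    # Split the main backward into dI (input grads) and dW (weight grads),
    # so dI can be scheduled first and dW deferred.
    bw_di_gm, bw_dw_gm = split_di_dw_graph(
        main_b_gm, num_weight_gradients=num_weight_gradients
    )
    # Optionally fuse the main forward and backward into a single module, e.g. to
    # interleave the forward of one microbatch with the backward of another.
    fused_gm = multiplex_fw_bw_graph(main_f_gm, main_b_gm)
    return prefetch_f_gm, main_f_gm, prefetch_b_gm, bw_di_gm, bw_dw_gm, fused_gm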
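The multiplexed module keeps the backward placeholders and outputs ahead of the forward ones, so a caller can split the results if it tracked the backward graph's output count. A sketch under that assumption, where `bw_args` and `fw_args` are hypothetical flattened inputs in the same order as each original graph's placeholders:

# Sketch: calling the multiplexed module produced above.
num_bw_outputs = len(main_b_gm.graph.find_nodes(op="output")[0].args[0])
outs = multiplexed_gm(*bw_args, *fw_args)  # backward placeholders come first
bw_outs, fw_outs = outs[:num_bw_outputs], outs[num_bw_outputs:]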