41 changes: 17 additions & 24 deletions autoparallel/_passes/graph_multiplex.py
@@ -39,10 +39,8 @@ def multiplex_fw_bw_graph(
multiplexed_gm = copy.deepcopy(fw_gm)

# Collect all placeholder nodes from the backward graph
bw_placeholders = []
for n in bw_gm.graph.nodes:
if n.op == "placeholder":
bw_placeholders.append(n)
bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")

# Insert backward placeholders at the beginning of the multiplexed graph
# Reversed order ensures correct execution sequence
@@ -54,21 +52,21 @@ def multiplex_fw_bw_graph(
old_node_to_new_node[n] = new_placeholder

# Find the last placeholder and the output node in the multiplexed graph
insert_point = None
multiplexed_graph_op_node = None
for n in multiplexed_gm.graph.nodes:
if n.op == "placeholder":
insert_point = n
if n.op == "output":
multiplexed_graph_op_node = n
multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
Contributor: typo multiplxed_
assert len(multiplxed_gm_placeholders) == (
len(fw_placeholders) + len(bw_placeholders)
)
insert_point = multiplxed_gm_placeholders[-1]

# Copy all computation nodes from backward graph into multiplexed graph
bw_graph_op_node = None
fw_outputs = fw_gm.graph.find_nodes(op="output")
bw_outputs = bw_gm.graph.find_nodes(op="output")
assert len(bw_outputs) == 1 and len(fw_outputs) == 1
bw_graph_op_node = bw_outputs[0]
for n in bw_gm.graph.nodes:
if n.op == "placeholder":
continue
if n.op == "output":
bw_graph_op_node = n
continue
with multiplexed_gm.graph.inserting_after(insert_point):
# Copy node and remap its arguments using the node mapping
@@ -79,25 +77,20 @@ def multiplex_fw_bw_graph(
old_node_to_new_node[n] = new_node
insert_point = new_node

assert bw_graph_op_node is not None
assert multiplexed_graph_op_node is not None

# Collect output arguments from backward graph, remapping to new nodes
bw_op_node_args = [
old_node_to_new_node[n] if n is not None else None
for n in bw_graph_op_node.args[0]
]

# Collect output arguments from forward graph
# Collect output arguments from multiplexed graph (will contain only fwd_outs)
multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
assert len(multiplexed_graph_outputs) == 1
Contributor: why this assert? Is this a tuple of forward outs, or are you hardcoding the model arch into the infra?

Contributor Author (@sanketpurandare, Nov 19, 2025): graph.find_nodes() returns a list of the nodes matching the filter criteria; this assert is a sanity check that there is exactly one output node. It says nothing about the args of the output node.
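A minimal sketch of the behavior being described, assuming a recent PyTorch build where Graph.find_nodes is available (the traced function below is illustrative, not from this PR):

```python
import torch.fx


def f(x):
    return x + 1, x * 2


gm = torch.fx.symbolic_trace(f)

# find_nodes filters the graph's nodes by op. A well-formed FX graph has
# exactly one output node, so the assert checks graph shape, not how many
# values the graph returns.
outputs = gm.graph.find_nodes(op="output")
assert len(outputs) == 1

# The single output node's first arg is the tuple of returned values.
print(len(outputs[0].args[0]))  # 2
```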

multiplexed_graph_op_node = multiplexed_graph_outputs[0]
fw_op_node_args = list(multiplexed_graph_op_node.args[0])

# Remove the old output node and create new combined output
insert_point = multiplexed_graph_op_node.prev
multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)

# Create combined output with backward outputs first, then forward outputs
with multiplexed_gm.graph.inserting_after(insert_point):
multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)
# Update output node args to prepend backward outputs before forward outputs
multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)

multiplexed_gm.graph.eliminate_dead_code()
multiplexed_gm.graph.lint()
5 changes: 4 additions & 1 deletion autoparallel/api.py
@@ -667,10 +667,14 @@ def apply_placement_pp(
assert num_params_buffers == (
num_params + num_buffers
), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
num_input_grads = (
len(bw_module.graph.find_nodes(op="output")[0].args[0]) - num_params_buffers
)
Member (@xmfan, Nov 19, 2025): I'm probably getting tripped up by the naming convention here; num_input_grads is generally filtered from num_fwd_outputs, right?

Contributor Author (@sanketpurandare, Nov 19, 2025): You can get the number of input_grads either from the placeholders of the fwd_graph or from the outputs of the bwd_graph. num_fwd_outputs gives you the num_out_grads (tangents) that are passed in as bwd_graph inputs.
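A small sketch restating that bookkeeping as a standalone helper (the helper name is hypothetical; it assumes the backward graph's single output node returns one gradient per differentiable forward input, with the param/buffer gradients first, matching the reply above):

```python
import torch.fx


def count_input_grads(bw_module: torch.fx.GraphModule, num_params_buffers: int) -> int:
    # The backward graph has exactly one output node; its first arg is the
    # tuple of gradients, one per differentiable forward input.
    bw_output = bw_module.graph.find_nodes(op="output")[0]
    # Dropping the leading param/buffer gradients leaves the user-input grads.
    return len(bw_output.args[0]) - num_params_buffers
```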

print(
f"num_params_buffers: {num_params_buffers}\n"
f"num_user_outputs: {num_user_outputs}\n"
f"num_mutate_inputs: {num_mutate_inputs}\n"
f"num_input_grads: {num_input_grads}\n"
f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"
f"num_symints_saved_for_bw: {num_symints_saved_for_bw}"
)
@@ -753,7 +757,6 @@ def apply_placement_pp(

bw_dI_module: Optional[torch.fx.GraphModule] = None
bw_dW_module: Optional[torch.fx.GraphModule] = None
num_input_grads = 0
if "split_dI_dW" in graph_passes:
from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
