Enable DualPipeV by adding a multiplexed graph #258
```diff
@@ -39,10 +39,8 @@ def multiplex_fw_bw_graph(
     multiplexed_gm = copy.deepcopy(fw_gm)

-    # Collect all placeholder nodes from the backward graph
-    bw_placeholders = []
-    for n in bw_gm.graph.nodes:
-        if n.op == "placeholder":
-            bw_placeholders.append(n)
+    bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
+    fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")

     # Insert backward placeholders at the beginning of the multiplexed graph
     # Reversed order ensures correct execution sequence
```
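Context for the API swap: `torch.fx.Graph.find_nodes(op=...)` (available in newer PyTorch versions) looks nodes up from an internal index instead of scanning `graph.nodes`. A minimal sketch, assuming a toy traced function, of the equivalence with the removed loop:

```python
import torch
import torch.fx

def f(x, y):
    return x * y + x

gm = torch.fx.symbolic_trace(f)

# Manual scan, as in the removed code
manual = [n for n in gm.graph.nodes if n.op == "placeholder"]
# Indexed lookup, as in the new code
indexed = gm.graph.find_nodes(op="placeholder")
assert list(indexed) == manual  # same nodes, same order: x, y
```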
```diff
@@ -54,21 +52,21 @@ def multiplex_fw_bw_graph(
         old_node_to_new_node[n] = new_placeholder

-    # Find the last placeholder and the output node in the multiplexed graph
-    insert_point = None
-    multiplexed_graph_op_node = None
-    for n in multiplexed_gm.graph.nodes:
-        if n.op == "placeholder":
-            insert_point = n
-        if n.op == "output":
-            multiplexed_graph_op_node = n
+    multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
+    assert len(multiplxed_gm_placeholders) == (
+        len(fw_placeholders) + len(bw_placeholders)
+    )
+    insert_point = multiplxed_gm_placeholders[-1]

     # Copy all computation nodes from backward graph into multiplexed graph
-    bw_graph_op_node = None
+    fw_outputs = fw_gm.graph.find_nodes(op="output")
+    bw_outputs = bw_gm.graph.find_nodes(op="output")
+    assert len(bw_outputs) == 1 and len(fw_outputs) == 1
+    bw_graph_op_node = bw_outputs[0]
     for n in bw_gm.graph.nodes:
         if n.op == "placeholder":
             continue
-        if n.op == "output":
-            bw_graph_op_node = n
-            continue
         with multiplexed_gm.graph.inserting_after(insert_point):
             # Copy node and remap its arguments using the node mapping
```
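The copy loop above is built on `Graph.node_copy` plus an old-to-new node mapping. A minimal, self-contained sketch (toy `fw`/`bw` functions, hypothetical names) of splicing one graph's compute nodes into another this way:

```python
import torch
import torch.fx

def fw(x):
    return x + 1

def bw(g):
    return g * 2

fw_gm = torch.fx.symbolic_trace(fw)
bw_gm = torch.fx.symbolic_trace(bw)

# Map each backward node to its clone in the combined graph
mapping = {}

# Add the backward placeholder after the forward ones
insert_point = fw_gm.graph.find_nodes(op="placeholder")[-1]
with fw_gm.graph.inserting_after(insert_point):
    new_ph = fw_gm.graph.placeholder("g")
mapping[bw_gm.graph.find_nodes(op="placeholder")[0]] = new_ph

# Clone compute nodes, remapping their args through the table
insert_point = new_ph
for n in bw_gm.graph.nodes:
    if n.op in ("placeholder", "output"):
        continue
    with fw_gm.graph.inserting_after(insert_point):
        new_node = fw_gm.graph.node_copy(n, lambda arg: mapping[arg])
    mapping[n] = new_node
    insert_point = new_node

fw_gm.graph.lint()  # combined graph is well-formed
```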
```diff
@@ -79,25 +77,20 @@ def multiplex_fw_bw_graph(
             old_node_to_new_node[n] = new_node
             insert_point = new_node

-    assert bw_graph_op_node is not None
-    assert multiplexed_graph_op_node is not None

     # Collect output arguments from backward graph, remapping to new nodes
     bw_op_node_args = [
         old_node_to_new_node[n] if n is not None else None
         for n in bw_graph_op_node.args[0]
     ]

-    # Collect output arguments from forward graph
+    # Collect output arguments from multiplexed graph (will contain only fwd_outs)
+    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
+    assert len(multiplexed_graph_outputs) == 1
```
Contributor:
why this assert? is this a tuple of forward outs or are you hardcoding the model arch into the infra?
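Background that may bear on the question: graphs produced by FX tracing carry exactly one `output` node, whose first argument packs every returned value, so the assert checks a structural invariant of FX rather than anything about the model architecture. A minimal sketch of that invariant:

```python
import torch
import torch.fx

def f(x):
    return x + 1, x * 2

gm = torch.fx.symbolic_trace(f)
out_nodes = gm.graph.find_nodes(op="output")
assert len(out_nodes) == 1    # a traced graph has a single output node
print(out_nodes[0].args[0])   # (add, mul): all returns packed in one arg
```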
```diff
+    multiplexed_graph_op_node = multiplexed_graph_outputs[0]
     fw_op_node_args = list(multiplexed_graph_op_node.args[0])

-    # Remove the old output node and create new combined output
-    insert_point = multiplexed_graph_op_node.prev
-    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)
-
-    # Create combined output with backward outputs first, then forward outputs
-    with multiplexed_gm.graph.inserting_after(insert_point):
-        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)
+    # Update output node args to prepend backward outputs before forward outputs
+    multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)

     multiplexed_gm.graph.eliminate_dead_code()
     multiplexed_gm.graph.lint()
```
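The in-place `args` update is the behavioral core of this hunk: instead of erasing the output node and emitting a new one, the existing node's packed arg tuple is rewritten. A minimal sketch, assuming a toy graph, of extending an FX graph's outputs this way:

```python
import torch
import torch.fx

def f(x):
    return (x + 1,)

gm = torch.fx.symbolic_trace(f)
out = gm.graph.find_nodes(op="output")[0]
x_ph = gm.graph.find_nodes(op="placeholder")[0]

# Prepend a value by rewriting the output node's args in place;
# no erase_node / graph.output round-trip needed
out.args = (tuple([x_ph] + list(out.args[0])),)

gm.graph.lint()
gm.recompile()
print(gm(torch.ones(2)))  # (tensor([1., 1.]), tensor([2., 2.]))
```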
```diff
@@ -667,10 +667,14 @@ def apply_placement_pp(
     assert num_params_buffers == (
         num_params + num_buffers
     ), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
+    num_input_grads = (
+        len(bw_module.graph.find_nodes(op="output")[0].args[0]) - num_params_buffers
+    )
```
Member:
i'm probably getting tripped up by the naming convention here,

Contributor (Author):
You can get the number of
```diff
     print(
         f"num_params_buffers: {num_params_buffers}\n"
         f"num_user_outputs: {num_user_outputs}\n"
         f"num_mutate_inputs: {num_mutate_inputs}\n"
+        f"num_input_grads: {num_input_grads}\n"
         f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"
         f"num_symints_saved_for_bw: {num_symints_saved_for_bw}"
     )
```
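A possible reading of the `num_input_grads` arithmetic, assuming the AOTAutograd convention that the backward module returns one gradient per primal, with parameter and buffer gradients laid out before user-input gradients; toy numbers only:

```python
# Toy numbers, purely illustrative
num_bw_outputs = 7           # len of the bw output node's packed args
num_params, num_buffers = 4, 1
num_params_buffers = num_params + num_buffers
num_input_grads = num_bw_outputs - num_params_buffers
assert num_input_grads == 2  # grads flowing back to user inputs
```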
```diff
@@ -753,7 +757,6 @@ def apply_placement_pp(

     bw_dI_module: Optional[torch.fx.GraphModule] = None
     bw_dW_module: Optional[torch.fx.GraphModule] = None
-    num_input_grads = 0
     if "split_dI_dW" in graph_passes:
         from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
```
typo: multiplxed_