
Conversation

@sanketpurandare (Contributor) commented Nov 18, 2025:

Stacked PRs:


Enable DualPipeV by adding a multiplexed graph. The multiplexed graph does not yet do any overlapping; it executes the backward and forward of different stages and microbatches within a single graph.

stack-info: PR: #258, branch: sanketpurandare/stack/1
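
A minimal, hypothetical sketch (not this PR's implementation) of the multiplexing idea: copy a backward graph and a forward graph into one torch.fx graph so a single call executes the backward of one stage/microbatch and then the forward of another, with no overlap. The helper name multiplex and the toy traced functions are invented for illustration.

import torch
from torch import fx

def multiplex(bw_gm: fx.GraphModule, fw_gm: fx.GraphModule) -> fx.GraphModule:
    g = fx.Graph()
    val_map: dict[fx.Node, fx.Node] = {}
    bw_out = g.graph_copy(bw_gm.graph, val_map)  # backward nodes/placeholders copied first
    fw_out = g.graph_copy(fw_gm.graph, val_map)  # forward nodes appended after them
    g.output((bw_out, fw_out))                   # flat outputs: (bw_outs, fw_outs)
    return fx.GraphModule({}, g)

bw_gm = fx.symbolic_trace(lambda grad: grad + 1)  # stand-in "backward" graph
fw_gm = fx.symbolic_trace(lambda x: x * 2)        # stand-in "forward" graph
combined = multiplex(bw_gm, fw_gm)
print(combined(torch.ones(2), torch.ones(2)))     # (tensor([2., 2.]), tensor([2., 2.]))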
@meta-cla bot added the CLA Signed label Nov 18, 2025
@sanketpurandare changed the title from "Enable DualPipeV by adding a multiplxed graph" to "Enable DualPipeV by adding a multiplexed graph" Nov 18, 2025
insert_point = n
if n.op == "output":
multiplexed_graph_op_node = n
multiplxed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
Contributor:
typo multiplxed_

# Collect output arguments from forward graph
# Collect output arguments from multiplexed graph (will contain only fwd_outs)
multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
assert len(multiplexed_graph_outputs) == 1
Contributor:
Why this assert? Is this a tuple of forward outs, or are you hardcoding the model arch into the infra?

@sanketpurandare (Contributor, Author) replied Nov 19, 2025:
graph.find_nodes() returns a list of nodes matching the filter criteria; this assert is a sanity check that there is exactly one output node. It says nothing about the args of the output node.
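
A minimal sketch (not from this PR) of the point: find_nodes() returns a list, and a traced fx graph has exactly one output node no matter how many values the module returns.

import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1, x * 2            # two returned values

gm = fx.symbolic_trace(M())
outputs = gm.graph.find_nodes(op="output")
assert len(outputs) == 1               # sanity check: a single output node
print(outputs[0].args)                 # that node's args hold all returned values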

) -> tuple[list[Any], list[Any], Any, tuple[tuple[Any], tuple[Any]]]:
multiplexed_outs = fx.Interpreter(multiplexed_fw_bw_module).boxed_run(bw_fw_args)

num_params_buffers = bw_graph_meta.num_params + bw_graph_meta.num_buffers
Contributor:
Maybe in a later PR, it would be nice if the decoding of the multiplexed_outs could be done by a module generated in the same place that generated the multiplexed graph. I mostly don't like how this function hardcodes a specific 'API' that is defined somewhere else, with no enforcement that the two stay consistent.
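
A hypothetical sketch of that suggestion (names invented, not this PR's code): the code that builds the multiplexed graph also emits a small spec that knows how to decode its flat outputs, so producer and consumer cannot drift apart.

from dataclasses import dataclass
from typing import Any

@dataclass
class MultiplexedOutSpec:
    num_param_buffer_grads: int
    num_input_grads: int

    def decode(self, flat_outs: list[Any]) -> tuple[list[Any], list[Any], list[Any]]:
        a = self.num_param_buffer_grads
        b = a + self.num_input_grads
        # (param/buffer grads, input grads, forward outputs)
        return flat_outs[:a], flat_outs[a:b], flat_outs[b:]

spec = MultiplexedOutSpec(num_param_buffer_grads=2, num_input_grads=1)
print(spec.decode([0, 1, 2, 3, 4]))    # ([0, 1], [2], [3, 4])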

numerics_logs: Optional[list[str]] = None,
) -> None:
) -> tuple[
_PipelineScheduleRuntime,
Contributor:
It's a bit odd to me that this util function is (a) not a class method on the runtime, and (b) returns a runtime object among other things. (From the description, mostly doing setup stuff and waiting on ops, it seems this function could plausibly have a null return type.)

mb_index,
is_next_stage_on_this_rank,
is_prev_stage_on_this_rank,
) = _prepare_fwd_common(action, ctx)
Contributor:
I see why it's this way now. But maybe you can:
(a) get any widely used objects off ctx into local vars inside stage_forward and pass those into the util functions;
(b) just access the less widely used objects directly off ctx inside the utils, to reduce the number of passed-around locals / API surface.
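
A rough sketch of (a)/(b) with invented names (not this PR's code): stage_forward pulls the widely used objects off ctx into locals and passes them explicitly, while the util reads the rarely used fields directly off ctx.

from dataclasses import dataclass, field
from typing import Any

@dataclass
class Ctx:
    stages: dict[int, Any] = field(default_factory=dict)     # widely used
    numerics_logs: list[str] = field(default_factory=list)   # rarely used

def _run_fwd(stage: Any, mb_index: int, ctx: Ctx) -> None:
    ctx.numerics_logs.append(f"fwd stage={stage} mb={mb_index}")  # (b) rare field read off ctx

def stage_forward(stage_index: int, mb_index: int, ctx: Ctx) -> None:
    stage = ctx.stages[stage_index]    # (a) widely used object -> local, passed explicitly
    _run_fwd(stage, mb_index, ctx)

ctx = Ctx(stages={0: "stage0"})
stage_forward(0, 3, ctx)
print(ctx.numerics_logs)               # ['fwd stage=stage0 mb=3']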

@sanketpurandare (Contributor, Author) replied:
Will try some of this refactoring in the next PR.

@wconstab (Contributor) left a comment:
mostly lgtm!

), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
num_input_grads = (
len(bw_module.graph.find_nodes(op="output")[0].args[0]) - num_params_buffers
)
@xmfan (Member) Nov 19, 2025:
I'm probably getting tripped up by the naming convention here; num_input_grads is generally filtered from num_fwd_outputs, right?

@sanketpurandare (Contributor, Author) replied Nov 19, 2025:
You can get the number of input_grads either from the placeholders of the fwd_graph or the outputs of the bwd_graph. num_fwd_outputs will give you the num_out_grads (tangents) that will be passed in as bwd_graph inputs.
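
A toy illustration of the two counts (assumed AOT-partitioner-style output layout, not this PR's code): the bwd graph's single output node returns grads for params and buffers first, then input grads, so subtracting the param/buffer count from the bwd output arity gives the number of input grads; num_fwd_outputs instead counts the tangents fed into the bwd graph.

num_params, num_buffers = 2, 1
num_params_buffers = num_params + num_buffers         # 3
len_bw_output_args = 5                                # 3 param/buffer grads + 2 input grads
num_input_grads = len_bw_output_args - num_params_buffers
assert num_input_grads == 2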

@sanketpurandare merged commit 06d5f49 into main Nov 19, 2025
6 checks passed