diff --git a/autoparallel/_passes/graph_multiplex.py b/autoparallel/_passes/graph_multiplex.py
index e1eadc8..387ad9f 100644
--- a/autoparallel/_passes/graph_multiplex.py
+++ b/autoparallel/_passes/graph_multiplex.py
@@ -39,10 +39,8 @@ def multiplex_fw_bw_graph(
     multiplexed_gm = copy.deepcopy(fw_gm)
 
     # Collect all placeholder nodes from the backward graph
-    bw_placeholders = []
-    for n in bw_gm.graph.nodes:
-        if n.op == "placeholder":
-            bw_placeholders.append(n)
+    bw_placeholders = bw_gm.graph.find_nodes(op="placeholder")
+    fw_placeholders = fw_gm.graph.find_nodes(op="placeholder")
 
     # Insert backward placeholders at the beginning of the multiplexed graph
     # Reversed order ensures correct execution sequence
@@ -54,21 +52,21 @@ def multiplex_fw_bw_graph(
         old_node_to_new_node[n] = new_placeholder
 
     # Find the last placeholder and the output node in the multiplexed graph
-    insert_point = None
-    multiplexed_graph_op_node = None
-    for n in multiplexed_gm.graph.nodes:
-        if n.op == "placeholder":
-            insert_point = n
-        if n.op == "output":
-            multiplexed_graph_op_node = n
+    multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder")
+    assert len(multiplexed_gm_placeholders) == (
+        len(fw_placeholders) + len(bw_placeholders)
+    )
+    insert_point = multiplexed_gm_placeholders[-1]
 
     # Copy all computation nodes from backward graph into multiplexed graph
-    bw_graph_op_node = None
+    fw_outputs = fw_gm.graph.find_nodes(op="output")
+    bw_outputs = bw_gm.graph.find_nodes(op="output")
+    assert len(bw_outputs) == 1 and len(fw_outputs) == 1
+    bw_graph_op_node = bw_outputs[0]
     for n in bw_gm.graph.nodes:
         if n.op == "placeholder":
             continue
         if n.op == "output":
-            bw_graph_op_node = n
             continue
         with multiplexed_gm.graph.inserting_after(insert_point):
             # Copy node and remap its arguments using the node mapping
@@ -79,25 +77,20 @@
             old_node_to_new_node[n] = new_node
         insert_point = new_node
 
-    assert bw_graph_op_node is not None
-    assert multiplexed_graph_op_node is not None
-
     # Collect output arguments from backward graph, remapping to new nodes
     bw_op_node_args = [
        old_node_to_new_node[n] if n is not None else None
        for n in bw_graph_op_node.args[0]
    ]
 
-    # Collect output arguments from forward graph
+    # Collect output arguments from multiplexed graph (will contain only fwd_outs)
+    multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output")
+    assert len(multiplexed_graph_outputs) == 1
+    multiplexed_graph_op_node = multiplexed_graph_outputs[0]
     fw_op_node_args = list(multiplexed_graph_op_node.args[0])
 
-    # Remove the old output node and create new combined output
-    insert_point = multiplexed_graph_op_node.prev
-    multiplexed_gm.graph.erase_node(multiplexed_graph_op_node)
-
-    # Create combined output with backward outputs first, then forward outputs
-    with multiplexed_gm.graph.inserting_after(insert_point):
-        multiplexed_gm.graph.output(bw_op_node_args + fw_op_node_args)
+    # Update output node args to prepend backward outputs before forward outputs
+    multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),)
 
     multiplexed_gm.graph.eliminate_dead_code()
     multiplexed_gm.graph.lint()
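Annotation (not part of the patch): the pass now leans on two `torch.fx` conveniences — `Graph.find_nodes` instead of hand-rolled node loops, and in-place rewriting of the output node's `args` instead of erasing and re-creating it. A minimal, hedged sketch of both on a toy traced function (the names `f` and `gm` are illustrative only):

```python
import torch
import torch.fx as fx

def f(x, y):
    return x + y

gm = fx.symbolic_trace(f)

# find_nodes replaces the "for n in graph.nodes: if n.op == ..." loops.
placeholders = gm.graph.find_nodes(op="placeholder")
(output_node,) = gm.graph.find_nodes(op="output")

# Rewriting the output node's args in place, as the pass now does, keeps the
# node's position in the graph -- no erase_node/inserting_after dance needed.
output_node.args = ((placeholders[0], output_node.args[0]),)
gm.graph.lint()
gm.recompile()

x, y = torch.randn(3), torch.randn(3)
extra, total = gm(x, y)
assert torch.equal(extra, x) and torch.equal(total, x + y)
```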
diff --git a/autoparallel/api.py b/autoparallel/api.py
index 621ee0f..c9995aa 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -667,10 +667,14 @@ def apply_placement_pp(
     assert num_params_buffers == (
         num_params + num_buffers
     ), f"num_params_buffers: {num_params_buffers}, num_params: {num_params}, num_buffers: {num_buffers}"
+    num_input_grads = (
+        len(bw_module.graph.find_nodes(op="output")[0].args[0]) - num_params_buffers
+    )
     print(
         f"num_params_buffers: {num_params_buffers}\n"
         f"num_user_outputs: {num_user_outputs}\n"
         f"num_mutate_inputs: {num_mutate_inputs}\n"
+        f"num_input_grads: {num_input_grads}\n"
         f"num_fw_outs_saved_for_bw: {num_fw_outs_saved_for_bw}\n"
         f"num_symints_saved_for_bw: {num_symints_saved_for_bw}"
     )
@@ -753,7 +757,6 @@ def apply_placement_pp(
 
     bw_dI_module: Optional[torch.fx.GraphModule] = None
     bw_dW_module: Optional[torch.fx.GraphModule] = None
-    num_input_grads = 0
     if "split_dI_dW" in graph_passes:
         from autoparallel._passes.split_di_dw_graph import split_di_dw_graph
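Annotation (not part of the patch): the new `num_input_grads` reads the arity of the backward graph's output node. The backward module's calling convention — matched by `_run_full_bw_module` in graph_pp_runner.py — returns parameter/buffer gradients first, then input gradients, so the count falls out by subtraction. A hedged sketch of the same arithmetic on a toy graph (`count_input_grads` and `fake_bw` are illustrative, not helpers in this PR):

```python
import torch.fx as fx

def count_input_grads(bw_module: fx.GraphModule, num_params_buffers: int) -> int:
    # Backward outputs are laid out as [param_grads..., buffer_grads..., input_grads...].
    (output_node,) = bw_module.graph.find_nodes(op="output")
    return len(output_node.args[0]) - num_params_buffers

def fake_bw(p, b, x):
    return p * 1, b * 1, x * 1  # one grad each for a param, a buffer, an input

assert count_input_grads(fx.symbolic_trace(fake_bw), num_params_buffers=2) == 1
```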
diff --git a/autoparallel/graph_pp_runner.py b/autoparallel/graph_pp_runner.py
index 2f089d8..3ae965c 100644
--- a/autoparallel/graph_pp_runner.py
+++ b/autoparallel/graph_pp_runner.py
@@ -48,6 +48,34 @@ class GraphMeta:
     num_input_grads: int
 
 
+def get_multiplexed_graph_callables(
+    stage_graphs: dict[int, GraphCallables]
+) -> dict[tuple[int, int], fx.GraphModule]:
+    """Generate multiplexed graph modules that fuse forward and backward passes from different stages.
+
+    Creates fused modules for all stage pairs where fw_stage_idx != bw_stage_idx. This enables
+    pipeline schedules (e.g., ZeroBubble) to overlap computation and reduce bubbles.
+
+    Args:
+        stage_graphs: Mapping from stage index to GraphCallables containing forward/backward modules.
+
+    Returns:
+        Mapping from (fw_stage_idx, bw_stage_idx) to fused GraphModule that executes
+        forward from fw_stage_idx and backward from bw_stage_idx.
+    """
+    from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
+
+    multiplexed_graph_callables: dict[tuple[int, int], torch.fx.GraphModule] = {}
+    for bw_stage_idx, bw_stage_graph_callables in stage_graphs.items():
+        for fw_stage_idx, fw_stage_graph_callables in stage_graphs.items():
+            if bw_stage_idx != fw_stage_idx:
+                fw_bw_module = multiplex_fw_bw_graph(
+                    fw_stage_graph_callables.fw, bw_stage_graph_callables.full_bw
+                )
+                multiplexed_graph_callables[(fw_stage_idx, bw_stage_idx)] = fw_bw_module
+    return multiplexed_graph_callables
+
+
 class GraphPipelineStage(PipelineStage):
     def __init__(
         self,
@@ -128,7 +156,7 @@ def _run_fw_module(
             fw_outputs = debug_interpreter.boxed_run(fw_args)
             numerics_logs += debug_interpreter.get_logs()
         else:
-            fw_outputs = torch.fx.Interpreter(fw_module).boxed_run(fw_args)
+            fw_outputs = fx.Interpreter(fw_module).boxed_run(fw_args)
 
     num_inner_fwd_outputs = graph_meta.num_mutate_inputs + graph_meta.num_user_outputs
     saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
@@ -147,7 +175,7 @@ def _run_fw_module(
 def _run_full_bw_module(
     bw_module: fx.GraphModule, graph_meta: GraphMeta, bw_args
 ) -> tuple[list[Any], list[Any]]:
-    bw_outputs = torch.fx.Interpreter(bw_module).boxed_run(bw_args)
+    bw_outputs = fx.Interpreter(bw_module).boxed_run(bw_args)
     num_params_buffers = graph_meta.num_params + graph_meta.num_buffers
     param_buffer_grads = bw_outputs[:num_params_buffers]
     input_grads = bw_outputs[num_params_buffers:]
@@ -157,7 +185,7 @@
 def _run_dI_bw_module(
     bw_dI_module: fx.GraphModule, graph_meta: GraphMeta, bw_dI_args
 ) -> tuple[list[Any], list[Any]]:
-    inp_grads_and_activations = torch.fx.Interpreter(bw_dI_module).boxed_run(bw_dI_args)
+    inp_grads_and_activations = fx.Interpreter(bw_dI_module).boxed_run(bw_dI_args)
     inp_grads, activations = inp_grads_and_activations[
         : graph_meta.num_input_grads
     ], list(inp_grads_and_activations[graph_meta.num_input_grads :])
@@ -167,24 +195,56 @@ def _run_dI_bw_module(
 def _run_dW_bw_module(
     bw_dW_module: fx.GraphModule, graph_meta: GraphMeta, bw_dW_args
 ) -> list[Any]:
-    param_buffer_grads = torch.fx.Interpreter(bw_dW_module).boxed_run(bw_dW_args)
+    param_buffer_grads = fx.Interpreter(bw_dW_module).boxed_run(bw_dW_args)
     return param_buffer_grads
 
 
 def _run_unshard_module(
     unshard_module: fx.GraphModule, graph_meta: GraphMeta, unshard_args
 ) -> list[Any]:
-    unsharded_params = torch.fx.Interpreter(unshard_module).boxed_run(unshard_args)
+    unsharded_params = fx.Interpreter(unshard_module).boxed_run(unshard_args)
     return unsharded_params
 
 
 def _run_reduce_grad_module(
     reduce_grad_module: fx.GraphModule, graph_meta: GraphMeta, reduce_grad_args
 ) -> list[Any]:
-    sharded_grads = torch.fx.Interpreter(reduce_grad_module).boxed_run(reduce_grad_args)
+    sharded_grads = fx.Interpreter(reduce_grad_module).boxed_run(reduce_grad_args)
     return sharded_grads
 
 
+def _run_multiplexed_fw_bw_module(
+    multiplexed_fw_bw_module: fx.GraphModule,
+    fw_graph_meta: GraphMeta,
+    bw_graph_meta: GraphMeta,
+    bw_fw_args,
+) -> tuple[list[Any], list[Any], Any, tuple[tuple[Any], tuple[Any]]]:
+    multiplexed_outs = fx.Interpreter(multiplexed_fw_bw_module).boxed_run(bw_fw_args)
+
+    num_params_buffers = bw_graph_meta.num_params + bw_graph_meta.num_buffers
+    num_bw_outs = bw_graph_meta.num_input_grads + num_params_buffers
+    bw_outputs = multiplexed_outs[:num_bw_outs]
+    param_buffer_grads = bw_outputs[:num_params_buffers]
+    input_grads = bw_outputs[num_params_buffers:]
+
+    fw_outputs = multiplexed_outs[num_bw_outs:]
+    num_inner_fwd_outputs = (
+        fw_graph_meta.num_mutate_inputs + fw_graph_meta.num_user_outputs
+    )
+    saved_intermediates = fw_outputs[num_inner_fwd_outputs:]
+    num_tensors_for_backward = (
+        len(saved_intermediates) - fw_graph_meta.num_symints_saved_for_bw
+    )
+    tensors_for_backward = saved_intermediates[:num_tensors_for_backward]
+    non_tensors_for_backward = saved_intermediates[num_tensors_for_backward:]
+    save_for_backward = (tensors_for_backward, non_tensors_for_backward)
+    user_outputs = fw_outputs[fw_graph_meta.num_mutate_inputs : num_inner_fwd_outputs]
+    if len(user_outputs) == 1:
+        user_outputs = user_outputs[0]
+
+    return input_grads, param_buffer_grads, user_outputs, save_for_backward
+
+
 def _get_stage_from_action(
     action: _Action,
     ctx: _PipelineContext,
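Annotation (not part of the patch): `_run_multiplexed_fw_bw_module` relies entirely on the positional layout that `multiplex_fw_bw_graph` establishes — backward outputs first, then forward outputs, each sub-sliced exactly as in `_run_full_bw_module` and `_run_fw_module`. A self-contained sketch of that slicing with illustrative counts (all names below are placeholders):

```python
# Illustrative counts for a stage with two params/buffers, one input, one user output.
num_params_buffers = 2
num_input_grads = 1
num_mutate_inputs, num_user_outputs = 0, 1

outs = ["w_grad", "b_grad", "x_grad", "loss", "saved0", "saved1"]

num_bw_outs = num_params_buffers + num_input_grads
bw_outs, fw_outs = outs[:num_bw_outs], outs[num_bw_outs:]
assert bw_outs[:num_params_buffers] == ["w_grad", "b_grad"]  # param/buffer grads
assert bw_outs[num_params_buffers:] == ["x_grad"]            # input grads
num_inner_fwd_outputs = num_mutate_inputs + num_user_outputs
assert fw_outs[num_mutate_inputs:num_inner_fwd_outputs] == ["loss"]
assert fw_outs[num_inner_fwd_outputs:] == ["saved0", "saved1"]  # saved for backward
```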
@@ -210,19 +270,40 @@
     return schedule, stage_index_to_stage, stage
 
 
-def stage_forward(
+def _prepare_fwd_common(
     action: _Action,
     ctx: _PipelineContext,
-    numerics_logs: Optional[list[str]] = None,
-) -> None:
+) -> tuple[
+    _PipelineScheduleRuntime,
+    dict[int, GraphPipelineStage],
+    GraphPipelineStage,
+    int,
+    bool,
+    bool,
+]:
+    """Common setup for forward stage: retrieve stage info and handle recv ops.
+
+    This function performs the shared initialization logic for forward operations,
+    including waiting for activation receives from the previous pipeline stage.
+
+    Args:
+        action: The forward action to execute, containing the stage index and microbatch index.
+        ctx: The pipeline context containing the schedule and pipeline state.
+
+    Returns:
+        A tuple containing:
+        - schedule: The pipeline schedule runtime object managing the execution.
+        - stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects.
+        - stage: The GraphPipelineStage for which forward is being computed.
+        - mb_index: The microbatch index being processed.
+        - is_next_stage_on_this_rank: True if stage_index + 1 exists on this rank (V-schedule).
+        - is_prev_stage_on_this_rank: True if stage_index - 1 exists on this rank (V-schedule).
+    """
     schedule, stage_index_to_stage, stage = _get_stage_from_action(action, ctx)
     stage_index = stage.stage_index
     mb_index = action.microbatch_index
     assert mb_index is not None
-    fwd_recv_ops = schedule.fwd_recv_ops
-    arg_mbs = ctx.arg_mbs
-    kwarg_mbs = ctx.kwarg_mbs
 
     is_next_stage_on_this_rank = stage_index + 1 in stage_index_to_stage
     is_prev_stage_on_this_rank = stage_index - 1 in stage_index_to_stage
@@ -232,13 +313,41 @@
         # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
         and not is_prev_stage_on_this_rank
     ):
+        fwd_recv_ops = schedule.fwd_recv_ops
         assert (
             stage_index,
             mb_index,
         ) in fwd_recv_ops, f"Computing {action=} before receiving input"
         _wait_batch_p2p(fwd_recv_ops.pop((stage_index, mb_index)))
 
+    return (
+        schedule,
+        stage_index_to_stage,
+        stage,
+        mb_index,
+        is_next_stage_on_this_rank,
+        is_prev_stage_on_this_rank,
+    )
+
+
+def _prepare_fwd_args(
+    stage: GraphPipelineStage,
+    mb_index: int,
+    ctx: _PipelineContext,
+) -> list[Any]:
+    """Prepare forward args from user inputs or received activations.
+
+    Args:
+        stage: The GraphPipelineStage for which to prepare forward arguments.
+        mb_index: The microbatch index being processed.
+        ctx: The pipeline context containing arg_mbs, kwarg_mbs, and target_mbs.
+
+    Returns:
+        List of forward arguments including unsharded_params, buffers, and composite_args.
+    """
+    arg_mbs = ctx.arg_mbs
+    kwarg_mbs = ctx.kwarg_mbs
+
     args = arg_mbs[mb_index]  # type: ignore[index]
     kwargs = kwarg_mbs[mb_index]  # type: ignore[index]
     assert not kwargs  # TODO: if kwargs can always be ignored, maybe remove?
@@ -257,19 +366,41 @@
         composite_args = composite_args + (ctx.target_mbs[mb_index],)  # type: ignore[index]
 
     # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
-    logger.debug(
-        "GraphPPRunner running action %s",
-        action,
-    )
 
     fw_args = [
         *stage.state["unsharded_params"],
         *stage.state["buffers"],
         *composite_args,
     ]
     del composite_args
-    output, saved_intermediates = _run_fw_module(
-        stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs
-    )
+    return fw_args
+
+
+def _post_fwd_common(
+    stage: GraphPipelineStage,
+    mb_index: int,
+    output: Any,
+    saved_intermediates: tuple[tuple[Any], tuple[Any]],
+    schedule: _PipelineScheduleRuntime,
+    stage_index_to_stage: dict[int, GraphPipelineStage],
+    ctx: _PipelineContext,
+    is_next_stage_on_this_rank: bool,
+) -> None:
+    """Common post-processing after forward pass: cache outputs and propagate.
+
+    This function handles the shared finalization logic for forward operations,
+    including normalizing outputs, caching for backward, validating outputs,
+    and propagating activations to the next pipeline stage.
+
+    Args:
+        stage: The stage that just completed forward computation.
+        mb_index: The microbatch index that was processed.
+        output: The output from the forward pass.
+        saved_intermediates: The intermediates saved for backward pass.
+        schedule: The pipeline schedule runtime object.
+        stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects.
+        ctx: The pipeline context.
+        is_next_stage_on_this_rank: True if the next stage exists on this rank.
+ """ # See [Note: pipeline model output type] output_tuple = _normalize_model_output_as_tuple(output) @@ -289,7 +420,45 @@ def stage_forward( # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank # see [Note: V-schedule special case] if is_next_stage_on_this_rank: - stage_index_to_stage[stage_index + 1].set_local_fwd_input(output, mb_index) + stage_index_to_stage[stage.stage_index + 1].set_local_fwd_input( + output, mb_index + ) + + +def stage_forward( + action: _Action, + ctx: _PipelineContext, + numerics_logs: Optional[list[str]] = None, +) -> None: + ( + schedule, + stage_index_to_stage, + stage, + mb_index, + is_next_stage_on_this_rank, + is_prev_stage_on_this_rank, + ) = _prepare_fwd_common(action, ctx) + + fw_args = _prepare_fwd_args(stage, mb_index, ctx) + + logger.debug( + "GraphPPRunner running action %s", + action, + ) + output, saved_intermediates = _run_fw_module( + stage.graph_callables.fw, stage.graph_meta, fw_args, numerics_logs=numerics_logs + ) + + _post_fwd_common( + stage, + mb_index, + output, + saved_intermediates, + schedule, + stage_index_to_stage, + ctx, + is_next_stage_on_this_rank, + ) def _prepare_backward_common( @@ -317,56 +486,54 @@ def _prepare_backward_common( A tuple containing: - schedule: The pipeline schedule runtime object managing the execution. - stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. - - backward_stage: The GraphPipelineStage for which backward is being computed. - - backward_mb_index: The microbatch index being processed. + - bw_stage: The GraphPipelineStage for which backward is being computed. + - bw_mb_index: The microbatch index being processed. - is_next_stage_on_this_rank: True if stage_index + 1 exists on this rank (V-schedule). - is_prev_stage_on_this_rank: True if stage_index - 1 exists on this rank (V-schedule). 
""" - schedule, stage_index_to_stage, backward_stage = _get_stage_from_action(action, ctx) + schedule, stage_index_to_stage, bw_stage = _get_stage_from_action(action, ctx) - backward_mb_index = action.microbatch_index - assert backward_mb_index is not None - is_next_stage_on_this_rank = backward_stage.stage_index + 1 in stage_index_to_stage - is_prev_stage_on_this_rank = backward_stage.stage_index - 1 in stage_index_to_stage + bw_mb_index = action.microbatch_index + assert bw_mb_index is not None + is_next_stage_on_this_rank = bw_stage.stage_index + 1 in stage_index_to_stage + is_prev_stage_on_this_rank = bw_stage.stage_index - 1 in stage_index_to_stage - if not backward_stage.is_last and not is_next_stage_on_this_rank: + if not bw_stage.is_last and not is_next_stage_on_this_rank: bwd_recv_ops = schedule.bwd_recv_ops assert ( - backward_stage.stage_index, - backward_mb_index, + bw_stage.stage_index, + bw_mb_index, ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input" - _wait_batch_p2p( - bwd_recv_ops.pop((backward_stage.stage_index, backward_mb_index)) - ) + _wait_batch_p2p(bwd_recv_ops.pop((bw_stage.stage_index, bw_mb_index))) - schedule.backward_counter[backward_stage.stage_index] += 1 + schedule.backward_counter[bw_stage.stage_index] += 1 return ( schedule, stage_index_to_stage, - backward_stage, - backward_mb_index, + bw_stage, + bw_mb_index, is_next_stage_on_this_rank, is_prev_stage_on_this_rank, ) def _prepare_backward_args( - backward_stage: GraphPipelineStage, - backward_mb_index: int, + bw_stage: GraphPipelineStage, + bw_mb_index: int, ) -> list[Any]: """Prepare backward kwargs from cached forward outputs.""" ( stage_output, saved_intermediates, - ) = backward_stage.fwd_cache.pop(backward_mb_index) + ) = bw_stage.fwd_cache.pop(bw_mb_index) - if backward_stage.is_last: + if bw_stage.is_last: assert len(stage_output) == 1 loss = stage_output[0] tangents = (torch.ones_like(loss),) else: - tangents = backward_stage._retrieve_recv_grads(backward_mb_index) + tangents = bw_stage._retrieve_recv_grads(bw_mb_index) tensors_for_backward, non_tensors_for_backward = saved_intermediates @@ -379,6 +546,42 @@ def _prepare_backward_args( return bw_args +def _post_backward_common( + bw_stage: GraphPipelineStage, + bw_mb_index: int, + input_grads: list[Any], + stage_index_to_stage: dict[int, GraphPipelineStage], + is_prev_stage_on_this_rank: bool, +) -> None: + """Common post-processing after backward pass: cache input grads and propagate. + + This function handles the shared finalization logic for backward operations, + including caching input gradients and propagating gradients to the previous + pipeline stage. + + Note: Gradient accumulation and scaling are NOT included here as they occur + at different points for full_backward vs split dI/dW: + - full_backward: accumulation and scaling happen immediately after backward + - split dI/dW: accumulation and scaling happen in backward_weight (dW), not backward_input (dI) + + Args: + bw_stage: The stage that just completed backward computation. + bw_mb_index: The microbatch index that was processed. + input_grads: The computed input gradients to cache. + stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. + is_prev_stage_on_this_rank: True if the previous stage exists on this rank. 
+ """ + bw_stage.bwd_cache[bw_mb_index] = ( + tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads + ) + + if is_prev_stage_on_this_rank: + stage_index_to_stage[bw_stage.stage_index - 1].set_local_bwd_input( + bw_stage.get_local_bwd_output(bw_mb_index), + bw_mb_index, + ) + + def stage_full_backward( action: _Action, ctx: _PipelineContext, @@ -386,53 +589,51 @@ def stage_full_backward( ( schedule, stage_index_to_stage, - backward_stage, - backward_mb_index, + bw_stage, + bw_mb_index, is_next_stage_on_this_rank, is_prev_stage_on_this_rank, ) = _prepare_backward_common(action, ctx) last_backward = ( - schedule.backward_counter[backward_stage.stage_index] - == schedule._n_microbatches + schedule.backward_counter[bw_stage.stage_index] == schedule._n_microbatches ) grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1 - if not backward_stage.has_backward: + if not bw_stage.has_backward: logger.debug("Returning early for backward stage") return - bwd_args = _prepare_backward_args(backward_stage, backward_mb_index) + bw_args = _prepare_backward_args(bw_stage, bw_mb_index) logger.debug( "GraphPPRunner running action %s", action, ) input_grads, param_buffer_grads = _run_full_bw_module( - backward_stage.graph_callables.full_bw, backward_stage.graph_meta, bwd_args + bw_stage.graph_callables.full_bw, bw_stage.graph_meta, bw_args ) - backward_stage.bwd_cache[backward_mb_index] = ( - tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads + bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads) + + _post_backward_common( + bw_stage, + bw_mb_index, + input_grads, + stage_index_to_stage, + is_prev_stage_on_this_rank, ) - backward_stage._accumulate_stage_unsharded_grads(param_buffer_grads) if last_backward: - backward_stage.scale_grads(grad_scale_factor) - - if is_prev_stage_on_this_rank: - stage_index_to_stage[backward_stage.stage_index - 1].set_local_bwd_input( - backward_stage.get_local_bwd_output(backward_mb_index), - backward_mb_index, - ) + bw_stage.scale_grads(grad_scale_factor) def stage_backward_input( action: _Action, ctx: _PipelineContext, ) -> None: - schedule, stage_index_to_stage, backward_stage = _get_stage_from_action(action, ctx) + schedule, stage_index_to_stage, bw_stage = _get_stage_from_action(action, ctx) - if backward_stage.is_first and backward_stage.graph_callables.bw_dI is None: + if bw_stage.is_first and bw_stage.graph_callables.bw_dI is None: # First stage does not have bw_dI graph since usually the inputs of the first stage do not require gradients # Hence, we do not do a split_dI_dW pass, and call full backward instead during dI action logger.debug( @@ -451,50 +652,50 @@ def stage_backward_input( ( schedule, stage_index_to_stage, - backward_stage, - backward_mb_index, + bw_stage, + bw_mb_index, is_next_stage_on_this_rank, is_prev_stage_on_this_rank, ) = _prepare_backward_common(action, ctx) - if not backward_stage.has_backward: + if not bw_stage.has_backward: logger.debug("Returning early for backward stage") return - bwd_args = _prepare_backward_args(backward_stage, backward_mb_index) + bw_args = _prepare_backward_args(bw_stage, bw_mb_index) logger.debug( "GraphPPRunner running action %s", action, ) - assert backward_stage.graph_callables.bw_dI is not None + assert bw_stage.graph_callables.bw_dI is not None input_grads, activations_for_backward = _run_dI_bw_module( - backward_stage.graph_callables.bw_dI, backward_stage.graph_meta, bwd_args - ) - backward_stage.bwd_cache[backward_mb_index] = ( - 
-        tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads
+        bw_stage.graph_callables.bw_dI, bw_stage.graph_meta, bw_args
     )
-    backward_stage.bwd_activation_cache[backward_mb_index] = (
+
+    bw_stage.bwd_activation_cache[bw_mb_index] = (
         tuple(activations_for_backward)
         if not isinstance(activations_for_backward, tuple)
         else activations_for_backward
     )
 
-    if is_prev_stage_on_this_rank:
-        stage_index_to_stage[backward_stage.stage_index - 1].set_local_bwd_input(
-            backward_stage.get_local_bwd_output(backward_mb_index),
-            backward_mb_index,
-        )
+    _post_backward_common(
+        bw_stage,
+        bw_mb_index,
+        input_grads,
+        stage_index_to_stage,
+        is_prev_stage_on_this_rank,
+    )
 
 
 def stage_backward_weight(
     action: _Action,
     ctx: _PipelineContext,
 ) -> None:
-    schedule, stage_index_to_stage, backward_stage = _get_stage_from_action(action, ctx)
-    backward_mb_index = action.microbatch_index
-    assert backward_mb_index is not None
-    if backward_stage.is_first and backward_stage.graph_callables.bw_dW is None:
+    schedule, stage_index_to_stage, bw_stage = _get_stage_from_action(action, ctx)
+    bw_mb_index = action.microbatch_index
+    assert bw_mb_index is not None
+    if bw_stage.is_first and bw_stage.graph_callables.bw_dW is None:
         # First stage does not have bw_dW graph since usually the inputs of the first stage do not require gradients
         # Hence, we do not do a split_dI_dW pass, and call full backward instead during dI action
         # which also performs dW implicitly, hence we skip this step.
@@ -505,32 +706,112 @@
         return
 
     last_backward = (
-        schedule.backward_counter[backward_stage.stage_index]
-        == schedule._n_microbatches
+        schedule.backward_counter[bw_stage.stage_index] == schedule._n_microbatches
    )
     grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
 
-    if not backward_stage.has_backward:
+    if not bw_stage.has_backward:
         logger.debug("Returning early for backward stage")
         return
 
-    activations_for_backward = backward_stage.bwd_activation_cache.pop(
-        backward_mb_index
-    )
+    activations_for_backward = bw_stage.bwd_activation_cache.pop(bw_mb_index)
 
     logger.debug(
         "GraphPPRunner running action %s",
         action,
     )
-    bwd_args = list(activations_for_backward)
+    bw_args = list(activations_for_backward)
     del activations_for_backward
-    assert backward_stage.graph_callables.bw_dW is not None
+    assert bw_stage.graph_callables.bw_dW is not None
     param_buffer_grads = _run_dW_bw_module(
-        backward_stage.graph_callables.bw_dW, backward_stage.graph_meta, bwd_args
+        bw_stage.graph_callables.bw_dW, bw_stage.graph_meta, bw_args
+    )
+    bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads)
+
+    if last_backward:
+        bw_stage.scale_grads(grad_scale_factor)
+
+
+def overlap_fw_bw(
+    multiplexed_graph_callables: dict[tuple[int, int], fx.GraphModule],
+    action: _Action,
+    ctx: _PipelineContext,
+) -> None:
+    assert action.sub_actions is not None, "Expected sub actions for overlap callback"
+    fw_action = action.sub_actions[0]
+    bw_action = action.sub_actions[1]
+
+    (
+        schedule,
+        stage_index_to_stage,
+        fw_stage,
+        fw_mb_index,
+        fw_is_next_stage_on_this_rank,
+        fw_is_prev_stage_on_this_rank,
+    ) = _prepare_fwd_common(fw_action, ctx)
+
+    (
+        _,
+        _,
+        bw_stage,
+        bw_mb_index,
+        bw_is_next_stage_on_this_rank,
+        bw_is_prev_stage_on_this_rank,
+    ) = _prepare_backward_common(bw_action, ctx)
+
+    last_backward = (
+        schedule.backward_counter[bw_stage.stage_index] == schedule._n_microbatches
+    )
+    grad_scale_factor = schedule._n_microbatches if schedule.scale_grads else 1
+
+    if not bw_stage.has_backward:
+        logger.debug("Returning early for backward stage")
+        return
+
+    fw_args = _prepare_fwd_args(fw_stage, fw_mb_index, ctx)
+    bw_args = _prepare_backward_args(bw_stage, bw_mb_index)
+    bw_fw_args = bw_args + fw_args
+    del bw_args, fw_args
+    multiplexed_fw_bw_module = multiplexed_graph_callables.get(
+        (fw_action.stage_index, bw_action.stage_index)
+    )
+    assert (
+        multiplexed_fw_bw_module is not None
+    ), "Expected multiplexed graph callables for overlap callback"
+    logger.debug(
+        "GraphPPRunner running action %s",
+        action,
+    )
+    (
+        input_grads,
+        param_buffer_grads,
+        output,
+        saved_intermediates,
+    ) = _run_multiplexed_fw_bw_module(
+        multiplexed_fw_bw_module, fw_stage.graph_meta, bw_stage.graph_meta, bw_fw_args
+    )
+
+    bw_stage._accumulate_stage_unsharded_grads(param_buffer_grads)
+
+    _post_fwd_common(
+        fw_stage,
+        fw_mb_index,
+        output,
+        saved_intermediates,
+        schedule,
+        stage_index_to_stage,
+        ctx,
+        fw_is_next_stage_on_this_rank,
+    )
+    _post_backward_common(
+        bw_stage,
+        bw_mb_index,
+        input_grads,
+        stage_index_to_stage,
+        bw_is_prev_stage_on_this_rank,
     )
-    backward_stage._accumulate_stage_unsharded_grads(param_buffer_grads)
 
     if last_backward:
-        backward_stage.scale_grads(grad_scale_factor)
+        bw_stage.scale_grads(grad_scale_factor)
 
 
 def stage_unshard(
diff --git a/examples/example_ds3_pp.py b/examples/example_ds3_pp.py
index 868e3a6..fbb7d0f 100644
--- a/examples/example_ds3_pp.py
+++ b/examples/example_ds3_pp.py
@@ -19,6 +19,7 @@
     BACKWARD_WEIGHT,
     FORWARD,
     FULL_BACKWARD,
+    OVERLAP_F_B,
     REDUCE_GRAD,
     RESHARD,
     UNSHARD,
@@ -47,6 +48,8 @@
     GraphMeta,
     GraphPipelineStage,
     GraphPPRunner,
+    get_multiplexed_graph_callables,
+    overlap_fw_bw,
     stage_backward_input,
     stage_backward_weight,
     stage_forward,
@@ -553,6 +556,11 @@ def last_stage_inp_with_loss_fn():
     schedule.register_custom_function(UNSHARD, stage_unshard)
     schedule.register_custom_function(BACKWARD_INPUT, stage_backward_input)
     schedule.register_custom_function(BACKWARD_WEIGHT, stage_backward_weight)
+    if schedule_name == "DualPipeV":
+        multiplexed_graph_callables = get_multiplexed_graph_callables(stage_graphs)
+        schedule.register_custom_function(
+            OVERLAP_F_B, functools.partial(overlap_fw_bw, multiplexed_graph_callables)
+        )
 
     # Step 7. Register the schedule with the graph runner
 
@@ -631,8 +639,8 @@ def last_stage_inp_with_loss_fn():
     parser.add_argument(
         "--schedule-name",
         type=str,
-        default="ZBVZeroBubble",
-        choices=["Interleaved1F1B", "ZBVZeroBubble"],
+        default="DualPipeV",
+        choices=["Interleaved1F1B", "ZBVZeroBubble", "DualPipeV"],
         help="Schedule to use for PP",
     )
     args = parser.parse_args()
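Annotation (not part of the patch): `get_multiplexed_graph_callables` builds a fused module for every mixed (fw, bw) stage pair on the rank, and the `OVERLAP_F_B` callback looks one up by the stage indices of its two sub-actions. A tiny sketch of the pair enumeration (indices are illustrative):

```python
# For a rank owning stages {0, 1}, every mixed (fw, bw) pair gets a fused
# module, so a DualPipeV OVERLAP_F_B action can find the module matching its
# forward sub-action and backward sub-action.
stage_indices = [0, 1]
pairs = [(fw, bw) for bw in stage_indices for fw in stage_indices if fw != bw]
assert sorted(pairs) == [(0, 1), (1, 0)]
```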
diff --git a/examples/example_pp_graph_passes.py b/examples/example_pp_graph_passes.py
index c8e7abe..e6d8e98 100644
--- a/examples/example_pp_graph_passes.py
+++ b/examples/example_pp_graph_passes.py
@@ -4,15 +4,20 @@
 # LICENSE file in the root directory of this source tree.
 
 from contextlib import nullcontext
-from typing import Callable
+from typing import Callable, Union
 
 import torch
-from torch._subclasses.fake_tensor import FakeTensorMode
+from torch._subclasses.fake_tensor import (
+    FakeTensor,
+    FakeTensorMode,
+    unset_fake_temporarily,
+)
 from torch.distributed.tensor import DeviceMesh, DTensor
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
+from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph
 from autoparallel._testing.models.dsv3 import (
     DeepSeekV3Model,
     DeepSeekV3ModelArgs,
@@ -27,11 +32,63 @@
     _run_dW_bw_module,
     _run_full_bw_module,
     _run_fw_module,
+    _run_multiplexed_fw_bw_module,
     _run_reduce_grad_module,
     _run_unshard_module,
 )
 
 
+def compare_tuples(tuple1: Union[list, tuple], tuple2: Union[list, tuple]) -> bool:
+    """Compare two tuples element-by-element with specialized comparison logic.
+
+    For each element pair:
+    - If both are FakeTensor: compare shape, stride, and dtype
+    - If both are Tensor: use torch.allclose for numerical comparison
+    - Otherwise: skip the comparison (always matches)
+
+    Args:
+        tuple1: First tuple to compare.
+        tuple2: Second tuple to compare.
+
+    Returns:
+        True if all comparable elements match according to the rules above, False otherwise.
+    """
+    if len(tuple1) != len(tuple2):
+        return False
+    with unset_fake_temporarily():
+        for elem1, elem2 in zip(tuple1, tuple2):
+            # Check if both are FakeTensor
+            if isinstance(elem1, FakeTensor):
+                if not isinstance(elem2, FakeTensor):
+                    return False
+                if elem1.dtype != elem2.dtype:
+                    return False
+                # Try to compare strides or shape, but skip if it would trigger data-dependent guards
+                try:
+                    if elem1.shape != elem2.shape:
+                        return False
+                    if elem1.stride() != elem2.stride():
+                        return False
+                except (
+                    torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
+                ):
+                    # Skip stride or shape comparison for symbolic shapes
+                    pass
+            # Check if both are regular Tensor (but not FakeTensor)
+            elif isinstance(elem1, torch.Tensor) and not isinstance(elem1, FakeTensor):
+                if not (
+                    isinstance(elem2, torch.Tensor)
+                    and not isinstance(elem2, FakeTensor)
+                ):
+                    return False
+                # Use torch.allclose for numerical comparison
+                if not torch.allclose(elem1, elem2):
+                    return False
+            # Otherwise, skip the comparison (neither or mismatched types)
+
+    return True
+
+
 def _get_pp_module_and_graphs(
     model: torch.nn.Module,
     mesh: DeviceMesh,
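Annotation (not part of the patch): a quick illustration of `compare_tuples`' semantics on real tensors, assuming it is imported from this module (the import path below is illustrative). Tensor entries must match numerically, non-tensor entries are skipped, and a length mismatch fails outright:

```python
import torch

from examples.example_pp_graph_passes import compare_tuples  # illustrative path

a = (torch.ones(2), "meta", 3)
b = (torch.ones(2), "other", 7)      # non-tensor entries are not compared
assert compare_tuples(a, b)
assert not compare_tuples(a, a[:2])  # length mismatch fails
assert not compare_tuples((torch.ones(2),), (torch.zeros(2),))  # values differ
```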
@@ -105,25 +162,30 @@ def _get_fw_inputs(
         v.to_local() if isinstance(v, DTensor) else v
         for k, v in dict(pp_mod.named_buffers(remove_duplicate=False)).items()
     ]
-    return [sharded_params, buffers, x]
+    return (sharded_params, buffers, x)
 
 
 # Symbolically evaluate in case you want to test running a graph bigger than your gpu
-def test_graph_partition(
-    model: torch.nn.Module,
-    mesh: DeviceMesh,
-    tracing_input_fn: Callable,
-    eval_input_fn: Callable,
-    fake_evaluate: bool = True,
-    use_loss_fn: bool = True,
-):
+def _run_graph_test(
+    pp_mod: torch.nn.Module,
+    graph_modules: GraphCallables,
+    graph_meta: GraphMeta,
+    sharded_params: list[torch.Tensor],
+    buffers: list[torch.Tensor],
+    x: list[torch.Tensor],
+    fake_evaluate: bool,
+    use_fsdp_collectives: bool,
+    use_split_dI_dW: bool,
+    use_multiplexed_graph: bool,
+) -> None:
+    """Execute forward and backward passes with specified graph options."""
+    if use_multiplexed_graph:
+        multiplexed_fw_bw_module = multiplex_fw_bw_graph(
+            graph_modules.fw, graph_modules.full_bw
+        )
 
-    pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
-        model, mesh, tracing_input_fn, use_loss_fn=use_loss_fn
-    )
-    sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
     with (
         FakeTensorMode(
             allow_non_fake_inputs=True,
             shape_env=ShapeEnv(),
         )
@@ -132,20 +194,58 @@
         if fake_evaluate
         else nullcontext()
     ):
-        # # now let's run it
         with torch.no_grad():
-            fw_args = [*sharded_params, *buffers, *x]
+            # Forward pass setup
+            if use_fsdp_collectives:
+                unshard_args = list(sharded_params)
+                assert graph_modules.unshard is not None
+                unsharded_params = _run_unshard_module(
+                    graph_modules.unshard, graph_meta, unshard_args
+                )
+                fw_args = [*unsharded_params, *buffers, *x]
+            else:
+                fw_args = [*sharded_params, *buffers, *x]
+            if use_multiplexed_graph:
+                m_fw_args = list(fw_args)
+            # Forward pass
             loss_or_output, saved_intermediates = _run_fw_module(
                 graph_modules.fw, graph_meta, fw_args
             )
             tangents = [torch.ones_like(loss_or_output)]
             tensors_for_backward, non_tensors_for_backward = saved_intermediates
 
+            # Backward pass setup
             bw_args = [
                 *non_tensors_for_backward,
                 *tensors_for_backward,
                 *tangents,
             ]
+            if use_multiplexed_graph:
+                m_bw_args = list(bw_args)
+                joint_args = m_bw_args + m_fw_args
+                del m_bw_args, m_fw_args
+                (
+                    m_input_grads,
+                    m_param_buffer_grads,
+                    m_loss_or_output,
+                    m_saved_intermediates,
+                ) = _run_multiplexed_fw_bw_module(
+                    multiplexed_fw_bw_module, graph_meta, graph_meta, joint_args
+                )
+                (
+                    m_tensors_for_backward,
+                    m_non_tensors_for_backward,
+                ) = m_saved_intermediates
+                assert compare_tuples((m_loss_or_output,), (loss_or_output,))
+                assert compare_tuples(m_tensors_for_backward, tensors_for_backward)
+                assert compare_tuples(
+                    m_non_tensors_for_backward, non_tensors_for_backward
+                )
+                del (
+                    m_non_tensors_for_backward,
+                    m_tensors_for_backward,
+                    m_loss_or_output,
+                )
             del (
                 tensors_for_backward,
                 non_tensors_for_backward,
@@ -153,10 +253,64 @@
                 tangents,
                 saved_intermediates,
             )
 
-            input_grads, param_buffer_grads = _run_full_bw_module(
-                graph_modules.full_bw, graph_meta, bw_args
-            )
+            # Backward pass
+            if use_split_dI_dW:
+                assert graph_modules.bw_dI is not None
+                input_grads, activations_for_backward = _run_dI_bw_module(
+                    graph_modules.bw_dI, graph_meta, bw_args
+                )
+                dw_args = list(activations_for_backward)
+                del activations_for_backward
+                assert graph_modules.bw_dW is not None
+                param_buffer_grads = _run_dW_bw_module(
+                    graph_modules.bw_dW, graph_meta, dw_args
+                )
+            else:
+                input_grads, param_buffer_grads = _run_full_bw_module(
+                    graph_modules.full_bw, graph_meta, bw_args
+                )
+            if use_multiplexed_graph:
+                assert compare_tuples(m_param_buffer_grads, param_buffer_grads)
+                assert compare_tuples(m_input_grads, input_grads)
+                del m_param_buffer_grads, m_input_grads
+            assert len(param_buffer_grads) == (len(sharded_params) + len(buffers))
+            unsharded_grads = list(param_buffer_grads[: len(sharded_params)])
+            del param_buffer_grads, input_grads
+            # Gradient reduction (if using FSDP collectives)
+            if use_fsdp_collectives:
+                assert graph_modules.reduce_grad is not None
+                sharded_grads = _run_reduce_grad_module(
+                    graph_modules.reduce_grad, graph_meta, unsharded_grads
+                )
+            else:
+                sharded_grads = unsharded_grads
+            assert len(sharded_grads) == len(sharded_params)
+
+
+def test_graph_partition(
+    model: torch.nn.Module,
+    mesh: DeviceMesh,
+    tracing_input_fn: Callable,
+    eval_input_fn: Callable,
+    fake_evaluate: bool = True,
+    use_loss_fn: bool = True,
+):
+    pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
+        model, mesh, tracing_input_fn, use_loss_fn=use_loss_fn
+    )
+    sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
+    _run_graph_test(
+        pp_mod,
+        graph_modules,
+        graph_meta,
+        sharded_params,
+        buffers,
+        x,
+        fake_evaluate,
+        use_fsdp_collectives=False,
+        use_split_dI_dW=False,
+        use_multiplexed_graph=True,
+    )
     print("All good!")
 
 
@@ -168,7 +322,6 @@ def test_split_fsdp_collectives(
     fake_evaluate: bool = True,
     use_loss_fn: bool = True,
 ):
-
     pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
         model,
         mesh,
@@ -177,50 +330,18 @@
         use_loss_fn=use_loss_fn,
     )
     sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
-    with (
-        FakeTensorMode(
-            allow_non_fake_inputs=True,
-            shape_env=ShapeEnv(),
-        )
-        if fake_evaluate
-        else nullcontext()
-    ):
-        # # now let's run it
-        with torch.no_grad():
-            unshard_args = list(sharded_params)
-            assert graph_modules.unshard is not None
-            unsharded_params = _run_unshard_module(
-                graph_modules.unshard, graph_meta, unshard_args
-            )
-            fw_args = [*unsharded_params, *buffers, *x]
-            loss_or_output, saved_intermediates = _run_fw_module(
-                graph_modules.fw, graph_meta, fw_args
-            )
-            tangents = [torch.randn_like(loss_or_output)]
-            tensors_for_backward, non_tensors_for_backward = saved_intermediates
-
-            bw_args = [
-                *non_tensors_for_backward,
-                *tensors_for_backward,
-                *tangents,
-            ]
-            del (
-                tensors_for_backward,
-                non_tensors_for_backward,
-                tangents,
-                saved_intermediates,
-            )
-            input_grads, unsharded_param_buffer_grads = _run_full_bw_module(
-                graph_modules.full_bw, graph_meta, bw_args
-            )
-            unsharded_grads = list(unsharded_param_buffer_grads[: len(sharded_params)])
-            del unsharded_param_buffer_grads, input_grads
-            assert graph_modules.reduce_grad is not None
-            sharded_grads = _run_reduce_grad_module(
-                graph_modules.reduce_grad, graph_meta, unsharded_grads
-            )
-            assert len(sharded_grads) == len(sharded_params)
-
+    _run_graph_test(
+        pp_mod,
+        graph_modules,
+        graph_meta,
+        sharded_params,
+        buffers,
+        x,
+        fake_evaluate,
+        use_fsdp_collectives=True,
+        use_split_dI_dW=False,
+        use_multiplexed_graph=True,
+    )
     print("All good!")
 
 
@@ -232,7 +353,6 @@ def test_split_dI_dW(
     fake_evaluate: bool = True,
     use_loss_fn: bool = True,
 ):
-
     pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
         model,
         mesh,
@@ -241,48 +361,18 @@
         use_loss_fn=use_loss_fn,
     )
     sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
-    with (
-        FakeTensorMode(
-            allow_non_fake_inputs=True,
-            shape_env=ShapeEnv(),
-        )
-        if fake_evaluate
-        else nullcontext()
-    ):
-        # # now let's run it
-        with torch.no_grad():
-            fw_args = [*sharded_params, *buffers, *x]
-            loss_or_output, saved_intermediates = _run_fw_module(
-                graph_modules.fw, graph_meta, fw_args
-            )
-            tangents = [torch.randn_like(loss_or_output)]
-            tensors_for_backward, non_tensors_for_backward = saved_intermediates
-
-            bw_args = [
-                *non_tensors_for_backward,
-                *tensors_for_backward,
-                *tangents,
-            ]
-            del (
-                tensors_for_backward,
-                non_tensors_for_backward,
-                tangents,
-                saved_intermediates,
-            )
-            assert graph_modules.bw_dI is not None
-            input_grads, activations_for_backward = _run_dI_bw_module(
-                graph_modules.bw_dI, graph_meta, bw_args
-            )
-            dw_args = list(activations_for_backward)
-            del activations_for_backward
-            assert graph_modules.bw_dW is not None
-            sharded_param_buffer_grads = _run_dW_bw_module(
-                graph_modules.bw_dW, graph_meta, dw_args
-            )
-            assert len(sharded_param_buffer_grads) == (
-                len(sharded_params) + len(buffers)
-            )
-
+    _run_graph_test(
+        pp_mod,
+        graph_modules,
+        graph_meta,
+        sharded_params,
+        buffers,
+        x,
+        fake_evaluate,
+        use_fsdp_collectives=False,
+        use_split_dI_dW=True,
+        use_multiplexed_graph=True,
+    )
     print("All good!")
 
 
@@ -294,7 +384,6 @@ def test_combined(
     fake_evaluate: bool = True,
     use_loss_fn: bool = True,
 ):
-
     pp_mod, graph_modules, graph_meta = _get_pp_module_and_graphs(
         model,
         mesh,
@@ -303,57 +392,18 @@
         use_loss_fn=use_loss_fn,
     )
     sharded_params, buffers, x = _get_fw_inputs(pp_mod, eval_input_fn)
-    with (
-        FakeTensorMode(
-            allow_non_fake_inputs=True,
-            shape_env=ShapeEnv(),
-        )
-        if fake_evaluate
-        else nullcontext()
-    ):
-        # # now let's run it
-        with torch.no_grad():
-            unshard_args = list(sharded_params)
-            assert graph_modules.unshard is not None
-            unsharded_params = _run_unshard_module(
-                graph_modules.unshard, graph_meta, unshard_args
-            )
-            fw_args = [*unsharded_params, *buffers, *x]
-            loss_or_output, saved_intermediates = _run_fw_module(
-                graph_modules.fw, graph_meta, fw_args
-            )
-            tangents = [torch.randn_like(loss_or_output)]
-            tensors_for_backward, non_tensors_for_backward = saved_intermediates
-
-            bw_args = [
-                *non_tensors_for_backward,
-                *tensors_for_backward,
-                *tangents,
-            ]
-            del (
-                tensors_for_backward,
-                non_tensors_for_backward,
-                tangents,
-                saved_intermediates,
-            )
-            assert graph_modules.bw_dI is not None
-            input_grads, activations_for_backward = _run_dI_bw_module(
-                graph_modules.bw_dI, graph_meta, bw_args
-            )
-            dw_args = list(activations_for_backward)
-            del activations_for_backward
-            assert graph_modules.bw_dW is not None
-            unsharded_param_buffer_grads = _run_dW_bw_module(
-                graph_modules.bw_dW, graph_meta, dw_args
-            )
-            unsharded_grads = list(unsharded_param_buffer_grads[: len(sharded_params)])
-            del unsharded_param_buffer_grads, input_grads
-            assert graph_modules.reduce_grad is not None
-            sharded_grads = _run_reduce_grad_module(
-                graph_modules.reduce_grad, graph_meta, unsharded_grads
-            )
-            assert len(sharded_grads) == len(sharded_params)
-
+    _run_graph_test(
+        pp_mod,
+        graph_modules,
+        graph_meta,
+        sharded_params,
+        buffers,
+        x,
+        fake_evaluate,
+        use_fsdp_collectives=True,
+        use_split_dI_dW=True,
+        use_multiplexed_graph=True,
+    )
     print("All good!")
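Annotation (not part of the patch): a minimal end-to-end sketch of the multiplexing contract the tests above exercise, on toy traced graphs rather than real pipeline stages. It assumes `autoparallel` is importable and that `multiplex_fw_bw_graph` accepts plain `symbolic_trace`d modules; `fw` and `bw` below are illustrative stand-ins for a stage's forward and full-backward graphs:

```python
import torch
import torch.fx as fx

from autoparallel._passes.graph_multiplex import multiplex_fw_bw_graph

def fw(x):
    return (x * 2,)

def bw(g):
    return (g + 1,)

fused = multiplex_fw_bw_graph(fx.symbolic_trace(fw), fx.symbolic_trace(bw))

# Backward placeholders come first in the fused graph, and backward outputs
# come first in its output -- the same convention _run_multiplexed_fw_bw_module
# uses when it builds bw_args + fw_args and slices the results.
g, x = torch.ones(3), torch.ones(3)
grad_out, fw_out = fx.Interpreter(fused).run(g, x)
assert torch.equal(grad_out, g + 1) and torch.equal(fw_out, x * 2)
```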