From 18ccbec9cc8de177a3d307fe53e7603cd467757f Mon Sep 17 00:00:00 2001
From: IvanKobzarev
Date: Fri, 10 Oct 2025 08:49:57 -0700
Subject: [PATCH] Pass to split all_gather prologue and reduce_scatter
 prologue from fsdp graph

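This adds two graph passes that split an FSDP graph produced by
AutoParallel:

* split_fsdp_prefetch walks forward from every placeholder along its
  chain of single-user, single-input ops and, when it reaches an
  all_gather_into_tensor, cuts right after the matching wait_tensor, so
  the prologue ends at the gathered tensors and the main graph consumes
  them as placeholders.
* split_fsdp_reduce_scatters_epilogue does the mirror-image walk
  backwards from every graph output and cuts right before each
  reduce_scatter_tensor, so the epilogue holds the reduce_scatters and
  everything downstream of them.

A minimal usage sketch, mirroring the test added below (`gm` stands for
a torch.fx.GraphModule produced by AutoParallel and `inputs` for example
inputs of its graph; both names are illustrative):

    from torch.fx import GraphModule

    from autoparallel.pipeline.passes import (
        split_fsdp_prefetch,
        split_fsdp_reduce_scatters_epilogue,
    )

    # Split the all_gather prologue off the front and the reduce_scatter
    # epilogue off the back of the graph.
    g_pro, g_main = split_fsdp_prefetch(gm.graph)
    g_main, g_epi = split_fsdp_reduce_scatters_epilogue(g_main)

    gm_pro = GraphModule(gm, g_pro)
    gm_main = GraphModule(gm, g_main)
    gm_epi = GraphModule(gm, g_epi)

    # Running the three stages back to back corresponds to running gm itself.
    out = gm_epi(*gm_main(*gm_pro(*inputs)))
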
stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/201, branch: IvanKobzarev/stack/9
---
 autoparallel/pipeline/passes.py | 147 ++++++++++++++++++++++++++++++++
 tests/test_pipeline_passes.py   | 100 ++++++++++++++++++++++
 2 files changed, 247 insertions(+)
 create mode 100644 autoparallel/pipeline/passes.py
 create mode 100644 tests/test_pipeline_passes.py

diff --git a/autoparallel/pipeline/passes.py b/autoparallel/pipeline/passes.py
new file mode 100644
index 0000000..f219199
--- /dev/null
+++ b/autoparallel/pipeline/passes.py
@@ -0,0 +1,147 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import dataclasses
+from contextlib import contextmanager
+from functools import partial
+from typing import Any
+
+import torch
+import torch.fx.node
+import torch.utils._pytree as pytree
+from torch._functorch._aot_autograd.descriptors import AOTOutput
+from torch._functorch.partitioners import _extract_graph_with_inputs_outputs
+from torch._inductor.fx_passes.bucketing import (
+    is_all_gather_into_tensor,
+    is_reduce_scatter_tensor,
+)
+
+
+@contextmanager
+def exclude_from_fx_side_effectful(exclude_vals: set[Any]):
+    original_val = torch.fx.node._side_effectful_functions.copy()
+    try:
+        torch.fx.node._side_effectful_functions -= exclude_vals
+        yield
+    finally:
+        torch.fx.node._side_effectful_functions.clear()
+        torch.fx.node._side_effectful_functions.update(original_val)
+
+
+exclude_wait_from_fx_side_effectful = partial(
+    exclude_from_fx_side_effectful,
+    {
+        torch.ops._c10d_functional.wait_tensor,
+        torch.ops._c10d_functional.wait_tensor.default,
+    },
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class PrefetchOutput(AOTOutput):
+    pass
+
+
+@dataclasses.dataclass(frozen=True)
+class EpilogueInput(AOTOutput):
+    pass
+
+
+def split_fsdp_prefetch(
+    g: torch.fx.Graph, stop_at_all_gather: bool = True
+) -> tuple[torch.fx.Graph, torch.fx.Graph]:
+    g_ins = g.find_nodes(op="placeholder")
+    prefetch_g_outs_map = []
+
+    for g_in in g_ins:
+        n = g_in
+        has_ag = False
+        while True:
+            if len(n.users) != 1:
+                break
+            user = next(iter(n.users))
+            if len(user.all_input_nodes) > 1:
+                break
+            n = user
+            if stop_at_all_gather and is_all_gather_into_tensor(n):
+                has_ag = True
+                w_n = next(iter(n.users))
+                n = w_n
+                break
+        if stop_at_all_gather and not has_ag:
+            prefetch_g_outs_map.append(g_in)
+        else:
+            prefetch_g_outs_map.append(n)
+
+    prefetch_g_outs = prefetch_g_outs_map
+    prefetch_g_outs_descs: list[AOTOutput] = [
+        PrefetchOutput() for _ in range(len(prefetch_g_outs))
+    ]
+    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
+    g_outs_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    )
+    with exclude_wait_from_fx_side_effectful():
+        prefetch_g = _extract_graph_with_inputs_outputs(
+            g,
+            g_ins,
+            prefetch_g_outs,
+            prefetch_g_outs_descs,
+        )
+
+    main_g = _extract_graph_with_inputs_outputs(
+        g,
+        prefetch_g_outs,
+        g_outs,
+        g_outs_descs,
+    )
+    return prefetch_g, main_g
+
+
+def split_fsdp_reduce_scatters_epilogue(
+    g: torch.fx.Graph,
+) -> tuple[torch.fx.Graph, torch.fx.Graph]:
+    g_ins = g.find_nodes(op="placeholder")
+    g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
+    g_outs_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    )
+
+    g_outs_map = []
+    for g_out in g_outs:
+        n = g_out
+        has_rs = False
+        while n is not None:
+            if len(n.all_input_nodes) != 1:
+                break
+            n_in = n.all_input_nodes[0]
+            if len(n_in.users) > 1:
+                break
+            prev_n = n
+            n = n_in
+            if is_reduce_scatter_tensor(prev_n):
+                has_rs = True
+                break
+        if has_rs:
+            g_outs_map.append(n)
+        else:
+            g_outs_map.append(g_out)
+
+    epi_g_ins = [n for n in g_outs_map if n is not None]
+    epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]
+    main_g = _extract_graph_with_inputs_outputs(
+        g,
+        g_ins,
+        epi_g_ins,
+        epi_g_ins_descs,
+    )
+    epi_g = _extract_graph_with_inputs_outputs(
+        g,
+        epi_g_ins,
+        g_outs,
+        g_outs_descs,
+    )
+
+    return main_g, epi_g
diff --git a/tests/test_pipeline_passes.py b/tests/test_pipeline_passes.py
new file mode 100644
index 0000000..91ffada
--- /dev/null
+++ b/tests/test_pipeline_passes.py
@@ -0,0 +1,100 @@
+from unittest.mock import patch
+
+import pytest
+import torch
+from torch import nn
+from torch.fx import GraphModule
+from torch.testing._internal.distributed.fake_pg import FakeStore
+
+from autoparallel.api import AutoParallel
+from autoparallel.pipeline.passes import (
+    split_fsdp_prefetch,
+    split_fsdp_reduce_scatters_epilogue,
+)
+
+
+@pytest.fixture(scope="module", autouse=True)
+def init_pg():
+    world_size = 256
+    fake_store = FakeStore()
+    if torch.distributed.is_initialized():
+        return
+    torch.distributed.init_process_group(
+        "fake", store=fake_store, rank=0, world_size=world_size
+    )
+
+
+@pytest.fixture(scope="module")
+def device_mesh_2d():
+    world_size = torch.distributed.get_world_size()
+    mesh = torch.distributed.device_mesh.init_device_mesh(
+        "cuda",
+        (world_size // 8, 8),
+        mesh_dim_names=(
+            "dp",
+            "tp",
+        ),
+    )
+    return mesh
+
+
+class FFN(nn.Module):
+    def __init__(self, dim1, dim2):
+        super().__init__()
+        bias = False
+        self.linear1 = nn.Linear(dim1, dim2, bias=bias)
+        self.linear2 = nn.Linear(dim2, dim1, bias=bias)
+
+    def forward(self, x, y):
+        return y + 2, self.linear2(self.linear1(x)), y + 2
+
+
+def _make_model_and_input_fn(mesh, device="cuda"):
+    bs = 2048 * mesh.shape[0]
+    dim1 = 1024
+    dim2 = 4096
+
+    def model_fn():
+        return FFN(dim1, dim2)
+
+    def input_fn():
+        return torch.randn(bs, dim1).to(device), torch.randn(bs, 1).to(device)
+
+    return model_fn, input_fn
+
+
+@patch("torch.cuda.device_count", lambda: 8)
+@patch("torch.cuda.get_device_name", lambda device: "H100")
+def test_fsdp_split_passes(device_mesh_2d):
+    low_mem = 0
+    high_mem = None
+    model_fn, input_fn = _make_model_and_input_fn(device_mesh_2d)
+    with torch.device("meta"):
+        model = model_fn()
+
+    with AutoParallel(model, input_fn, device_mesh_2d) as autop:
+        autop.add_parameter_memory_constraint(low=low_mem, high=high_mem)
+        sharding_placement = autop.optimize_placement()
+        autop.apply_placement(sharding_placement)
+    gm = autop.parallel_gm
+    g = gm.graph
+
+    def gen_g_inputs(g):
+        phs = g.find_nodes(op="placeholder")
+        ret = []
+        for ph in phs:
+            ft = ph.meta["val"]
+            t = torch.randn(ft.shape, dtype=ft.dtype, device=ft.device)
+            ret.append(t)
+        return ret
+
+    inputs = gen_g_inputs(g)
+    g_pro, g_main = split_fsdp_prefetch(g)
+    g_main, g_epi = split_fsdp_reduce_scatters_epilogue(g_main)
+
+    gm_pro = GraphModule(gm, g_pro)
+    gm_main = GraphModule(gm, g_main)
+    gm_epi = GraphModule(gm, g_epi)
+
+    gm(*inputs)
+    gm_epi(*gm_main(*gm_pro(*inputs)))