From 89afcffdde3ca02265eac6bdce5d9769695aa078 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Mon, 29 Sep 2025 04:04:36 -0700 Subject: [PATCH] [asynctp] Passes, ops stack-info: PR: https://github.com/meta-pytorch/autoparallel/pull/167, branch: IvanKobzarev/stack/1 --- autoparallel/asynctp.py | 1234 +++++++++++++++++++++++++ autoparallel/asynctp_ops.py | 1705 +++++++++++++++++++++++++++++++++++ examples/example_llama3.py | 20 + 3 files changed, 2959 insertions(+) create mode 100644 autoparallel/asynctp.py create mode 100644 autoparallel/asynctp_ops.py diff --git a/autoparallel/asynctp.py b/autoparallel/asynctp.py new file mode 100644 index 00000000..9b708ccd --- /dev/null +++ b/autoparallel/asynctp.py @@ -0,0 +1,1234 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +# Temporary fork of torch/_inductor/fx_passes/micro_pipeline_tp.py +# mypy: allow-untyped-defs +import logging +import operator +from collections import defaultdict +from dataclasses import dataclass, field +from math import prod +from typing import Any, Optional, cast + +import torch +from torch._inductor import inductor_prims +from torch._inductor.pattern_matcher import ( + MULTIPLE, + CallFunction, + Ignored, + KeywordArg, + ListOf, + Match, + PatternExpr, + PatternMatcherPass, +) +from torch.utils._ordered_set import OrderedSet + +import autoparallel.asynctp_ops # noqa: F401 + +log = logging.getLogger(__name__) +aten = torch.ops.aten +patterns = PatternMatcherPass() + +_micro_pipeline_tp_ag_transpose_mm_enabled = True + +# Check performance if overhead of decomposition outweights pipeline wins +_micro_pipeline_tp_ag_mm_last_dim_enabled = False + +_micro_pipeline_tp_mm_rs_last_dim_enabled = True + + +def _is_backward(graph: torch.fx.Graph) -> bool: + placeholders = [] + for node in graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return not all(node.name.startswith("primal") for node in placeholders) + + +def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float: + return M * N * K / (M * K + N * K + M * N) + + +def _filter_nodes_by_target(nodes: list[torch.fx.Node], target) -> list[torch.fx.Node]: + return [x for x in nodes if x.target == target] + + +def _find_ancestors(node: torch.fx.Node) -> OrderedSet[torch.fx.Node]: + ancestors = OrderedSet[torch.fx.Node]() + ancestors.add(node) + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.all_input_nodes: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return OrderedSet(node for node in ancestors if node.op != "placeholder") + + +def _get_tensor(node: torch.fx.Node) -> torch.Tensor: + val = node.meta["val"] + assert isinstance(val, torch.Tensor) + return val + + +@dataclass +class _AllGatherMatch: + match: Match + shard_node: torch.fx.Node + ag_node: torch.fx.Node + res_node: torch.fx.Node + gather_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + self.res_node.replace_all_uses_with(new_node) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_all_gather_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def make_zero_dim_all_gather_pattern(shard): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + 
c10d.all_gather_into_tensor.default, + shard, + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.all_gather_tensor with gather_dim == 0 + zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard")) + + def make_all_gather_split_pattern(shard): + return CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + make_zero_dim_all_gather_pattern(shard), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ) + + def make_cat_pattern(splits): + return CallFunction( + aten.cat.default, + ListOf(splits), + KeywordArg("gather_dim"), + ) + + # Matches funcol.all_gather_tensor with gather_dim > 0 + non_zero_dim_all_gather_pattern = make_cat_pattern( + make_all_gather_split_pattern(KeywordArg("shard")), + ) + + # Match a zero-dim all-gather in which the data is transferred as uint8 and + # viewed back as the original dtype. + zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_zero_dim_all_gather_pattern( + KeywordArg("shard"), + ), + Ignored(), + ) + + # Match a non-zero dim all-gather in which the data is transferred as uint8 + # and viewed back as the original dtype. + non_zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_cat_pattern( + CallFunction( + aten.view.dtype, + make_all_gather_split_pattern( + KeywordArg("shard"), + ), + Ignored(), + ), + ), + Ignored(), + ) + + # If two patterns with the same res_node_target have the same suffix, the + # longer pattern should appear first in the list. + # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1) + # should appear before (2) in the list. + res_node_target_to_patterns = { + aten.cat.default: [ + (non_zero_dim_all_gather_pattern, 0), + ], + aten.view.dtype: [ + (non_zero_dim_type_erased_all_gather_pattern, 0), + (zero_dim_type_erased_all_gather_pattern, 0), + ], + c10d.wait_tensor.default: [ + (zero_dim_all_gather_pattern, 0), + ], + } + + # Match in reverse to ensure longer patterns is prioritized + all_gathers = [] + visited_ag_nodes = OrderedSet[torch.fx.Node]() + for node in reversed(graph.nodes): + for target, patterns in res_node_target_to_patterns.items(): + if node.target != target: + continue + for pattern, ag_node_idx in patterns: + match = pattern.match(node) + if not match: + continue + + assert isinstance(match, Match) + ag_node = match.nodes[ag_node_idx] + assert ag_node.target == c10d.all_gather_into_tensor.default + + if ag_node in visited_ag_nodes: + continue + visited_ag_nodes.add(ag_node) + + ag_match = _AllGatherMatch( + match=match, + shard_node=match.kwargs["shard"], + ag_node=ag_node, + res_node=node, + gather_dim=match.kwargs.get("gather_dim", 0), + group_name=match.kwargs["group_name"], + ) + all_gathers.append(ag_match) + + return list(reversed(all_gathers)) + + +@dataclass +class _ReduceScatterMatch: + match: Match + input_node: torch.fx.Node + reduce_scatter_node: torch.fx.Node + wait_tensor_node: torch.fx.Node + reduce_op: str + scatter_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + # Replace all uses of the result node (wait_tensor) with the fused node. + self.wait_tensor_node.replace_all_uses_with(new_node) + + # If the reduce-scatter result is saved for backward, save the fused node for backward instead. 
+ self._update_save_for_backward(new_node) + + def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: + """ + If the output node is a user of the reduce_scatter node (indicating the reduce_scatter + result is saved for backward), this method will update the output node to use the fused node instead. + """ + output_node = None + for user in self.reduce_scatter_node.users: + if user.target == "output": + output_node = user + break + if output_node is not None: + output_node.replace_input_with(self.reduce_scatter_node, new_node) + + # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not + # saved for backward anymore. + assert ( + len(self.reduce_scatter_node.users) == 1 + ), "Reduce scatter node has multiple users, this is not expected" + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_reduce_scatter_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def reduce_scatter_template(inp: PatternExpr, users: int): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + inp, + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + _users=users, + ), + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 + zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + KeywordArg("input"), users=1 + ) + + # Two users will occur when the reduce-scatter result is saved for backward + zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + KeywordArg("input"), users=2 + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 + non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=1, + ) + + # Two users will occur when the reduce-scatter result is saved for backward + non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=2, + ) + + reduce_scatters = [] + for node in reversed(graph.nodes): + if node.target == c10d.wait_tensor.default: + if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + 
reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + return list(reversed(reduce_scatters)) + + +@dataclass +class _Matmul: + nodes: list[torch.fx.Node] + arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False) + A_node: torch.fx.Node + B_node: torch.fx.Node + pre_mm_reshape: Optional[torch.fx.Node] + post_mm_reshape: Optional[torch.fx.Node] + pre_mm_B_transpose: Optional[torch.fx.Node] + + def __post_init__(self): + assert len(self.nodes) in (1, 2, 3) + if len(self.nodes) == 1: + assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default) + self.arg_ancestor_nodes = _find_ancestors(self.B_node) + elif len(self.nodes) == 2: + assert self.nodes[0].target == aten.permute.default + assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default) + self.arg_ancestor_nodes = _find_ancestors(self.A_node) + else: + assert self.nodes[0].target == aten.reshape.default + assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default) + assert self.nodes[2].target == aten.reshape.default + self.arg_ancestor_nodes = _find_ancestors(self.B_node) + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + """ + graph = new_node.graph + + # For 2D-matmuls, we simply replace the mm node with `new_node`. + if len(self.nodes) == 1: + mm_node = self.nodes[0] + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + mm_node.replace_all_uses_with(new_node) + graph.erase_node(mm_node) + return + + if len(self.nodes) == 2: + permute_node = self.nodes[0] + mm_node = self.nodes[1] + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + mm_node.replace_all_uses_with(new_node) + graph.erase_node(mm_node) + if len(permute_node.users) == 0: + graph.erase_node(permute_node) + return + # An ND-matmul is reshape -> mm -> reshape sequence. We first replace + # the second reshape node with `new_node`. Then, we ensure that the + # original mm node in the sequence ends up with zero users by replacing + # it with a reverse reshape of `new_node`. 
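+        # For illustration (hypothetical shapes): with A of shape (a, b, k) and
+        # B of shape (k, n), the matched sequence is
+        #     reshape(A, (a*b, k)) -> mm(., B) -> reshape(., (a, b, n)),
+        # so `new_node` of shape (a, b, n) replaces the second reshape, and any
+        # remaining users of the 2D mm output instead read
+        # aten.reshape(new_node, (a*b, n)).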
+ graph = new_node.graph + assert len(self.nodes) == 3 + mm_node = self.nodes[1] + output_reshape_node = self.nodes[2] + + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + assert output_reshape_node.target == aten.reshape.default + + output_reshape_node.replace_all_uses_with(new_node) + if len(mm_node.users) > 1: + with graph.inserting_after(new_node): + new_mm_node = graph.call_function( + aten.reshape.default, + args=(new_node, list(_get_tensor(mm_node).shape)), + ) + mm_node.replace_all_uses_with(new_mm_node) + + def erase(self) -> None: + for node in reversed(self.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_Matmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten.mm.default, + aten.reshape.default, + ) + mm_node = match[0] if len(match) == 1 else match[1] + return _Matmul( + nodes=match, + A_node=cast("torch.fx.Node", match[0].args[0]), + B_node=cast("torch.fx.Node", mm_node.args[1]), + # _Matmul handles reshapes via custom graph manipulation logic, see `replace_with()` method. + # TODO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes. + pre_mm_reshape=None, + post_mm_reshape=None, + pre_mm_B_transpose=None, + ) + + +@dataclass +class _ScaledMatmul(_Matmul): + A_scale_node: torch.fx.Node + B_scale_node: torch.fx.Node + bias_node: Optional[torch.fx.Node] + result_scale_node: Optional[torch.fx.Node] + out_dtype: Optional[torch.dtype] + use_fast_accum: bool + pre_mm_reshape: Optional[torch.fx.Node] + post_mm_reshape: Optional[torch.fx.Node] + + def __post_init__(self): + super().__post_init__() + self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node) + self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_ScaledMatmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten._scaled_mm.default, + aten.reshape.default, + ) + + def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any: + if idx >= len(node.args): + return default + return node.args[idx] + + # Use mm_node with 2D args for both A and B, even if this is a "reshape -> mm -> reshape" pattern. + # We will store the reshapes in pre_mm_reshape and post_mm_reshape, to be referenced later to + # produce the correct output shapes, reduce-scatter along the correct dimensions, etc. 
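+        # For illustration (hypothetical shapes): a 2D match is [scaled_mm],
+        # while an ND match is [reshape, scaled_mm, reshape], e.g.
+        #     A (a, b, k) -> reshape (a*b, k) -> scaled_mm with B (k, n) -> reshape (a, b, n).
+        # pre_mm_reshape / post_mm_reshape record those two reshape nodes so the
+        # reduce-scatter fusion can recover the scatter dim and the ND output shape.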
+ is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default + mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0] + pre_mm_reshape = match[0] if is_reshape_mm_reshape_pattern else None + post_mm_reshape = match[-1] if is_reshape_mm_reshape_pattern else None + A_node = cast("torch.fx.Node", mm_node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + A_scale_node = cast("torch.fx.Node", mm_node.args[2]) + B_scale_node = cast("torch.fx.Node", mm_node.args[3]) + + return _ScaledMatmul( + nodes=match, + A_node=A_node, + B_node=B_node, + A_scale_node=A_scale_node, + B_scale_node=B_scale_node, + bias_node=get_arg(mm_node, 4, None), + result_scale_node=get_arg(mm_node, 5, None), + out_dtype=get_arg(mm_node, 6, None), + use_fast_accum=get_arg(mm_node, 7, False), + pre_mm_reshape=pre_mm_reshape, + post_mm_reshape=post_mm_reshape, + pre_mm_B_transpose=None, + ) + + +def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]: + if node.target != aten.reshape.default: + return [] + + matches = [] + for mm_node in node.users: + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + continue + for reshape_node in mm_node.users: + if reshape_node.target != aten.reshape.default: + continue + + # Since the reshape -> mm -> reshape pattern would be subsumed into + # the fused op, we only match the patterns where the shape of the + # second reshape is matches the mm result produced by the fused op. + matmul_input_node = cast("torch.fx.Node", node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + matmul_out_shape = torch.Size( + [ + *_get_tensor(matmul_input_node).shape[:-1], + _get_tensor(B_node).shape[-1], + ] + ) + if _get_tensor(reshape_node).shape != matmul_out_shape: + continue + matches.append([node, mm_node, reshape_node]) + # If for some rare reason mm_node is being reshaped by two + # different reshape nodes, we only include mm_node once in the + # parsing result. + break + + matmuls = [] + for match in matches: + mm_node = match[1] + if mm_node.target == aten.mm.default: + matmul = _Matmul.from_match(match) + matmuls.append(matmul) + elif mm_node.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match(match) + matmuls.append(matmul) + else: + raise AssertionError( + "Expect the node's target to be either aten.mm.default or " + f"aten._scaled_mm.default. Got {mm_node.target}." + ) + return matmuls + + +def _find_permute_mm(permute_node: torch.fx.Node) -> list[_Matmul]: + for mm_node in permute_node.users: + if mm_node.target != aten.mm.default: + continue + if permute_node == mm_node.args[1]: + return [ + _Matmul( + nodes=[permute_node, mm_node], + A_node=cast("torch.fx.Node", mm_node.args[0]), + B_node=cast("torch.fx.Node", mm_node.args[1]), + pre_mm_reshape=None, + post_mm_reshape=None, + pre_mm_B_transpose=permute_node, + ) + ] + return [] + + +def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]: + """ + Find the matmuls that use `node` as the lhs argument. 
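+
+    Recognized consumer patterns (schematic):
+
+        mm(node, B)                         # 2D matmul
+        reshape -> mm(., B) -> reshape      # ND matmul
+        scaled_mm(node, B, ...)             # 2D scaled matmul
+        mm(X, permute(node, [1, 0]))        # B-transposed matmul, when
+                                            # _micro_pipeline_tp_ag_transpose_mm_enabled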
+ """ + matmuls = [] + for user in node.users: + # ND matmuls + if user.target == aten.reshape.default: + matmuls.extend(_find_reshape_mm_reshape(user)) + # 2D matmuls + elif user.target == aten.mm.default: + matmul = _Matmul.from_match(match=[user]) + matmuls.append(matmul) + elif user.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match([user]) + matmuls.append(matmul) + elif ( + _micro_pipeline_tp_ag_transpose_mm_enabled + and user.target == aten.permute.default + and (user.args[1] == [1, 0] or user.args[1] == [0, 1]) + ): + permute_matmuls = _find_permute_mm(user) + if permute_matmuls: + if not matmuls: + matmuls.extend(permute_matmuls) + else: + has_not_permute_matmul = False + for matmul in matmuls: + if matmul.pre_mm_B_transpose: + has_not_permute_matmul = True + if not has_not_permute_matmul: + matmuls.extend(permute_matmuls) + + return matmuls + + +def _insert_fused_all_gather_matmul( + graph: torch.fx.Graph, + matmuls: list[_Matmul], + shard_node: torch.fx.Node, + gather_dim: int, + group_name: str, +) -> torch.fx.Node: + mm_types = OrderedSet(map(type, matmuls)) + assert len(mm_types) == 1 + mm_type = next(iter(mm_types)) + if mm_type == _Matmul: + B_nodes = [matmul.B_node for matmul in matmuls] + return graph.call_function( + torch.ops._ap_symm_mem.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + kwargs={"return_A": True}, + ) + elif mm_type == _ScaledMatmul: + scaled_matmuls = cast("list[_ScaledMatmul]", matmuls) + return graph.call_function( + torch.ops._ap_symm_mem.fused_all_gather_scaled_matmul.default, + args=( + shard_node, + [matmul.B_node for matmul in scaled_matmuls], + scaled_matmuls[0].A_scale_node, + [matmul.B_scale_node for matmul in scaled_matmuls], + gather_dim, + group_name, + [matmul.bias_node for matmul in scaled_matmuls], + [matmul.result_scale_node for matmul in scaled_matmuls], + [matmul.out_dtype for matmul in scaled_matmuls], + [matmul.use_fast_accum for matmul in scaled_matmuls], + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {mm_type}") + + +def graph_call_function_transpose(graph, n): + return graph.call_function( + torch.ops.aten.permute.default, + args=(n, [1, 0]), + ) + + +def graph_call_function_contiguous(graph, n): + return graph.call_function( + torch.ops.aten.clone.default, + args=(n,), + kwargs={"memory_format": torch.contiguous_format}, + ) + + +# mat_B = ag(shard) +# mat_B_t = mat_B.t() +# return mm(mat_A, mat_B_t) +# -> +# mat_A_t = mat_A.t() +# mat_B = ag(shard) +# res_mm_t = mm(mat_B, mat_A_t) +# return res_mm_t +def _insert_fused_all_gather_transpose_matmul( + graph: torch.fx.Graph, + matmuls: list[_Matmul], + shard_node: torch.fx.Node, + gather_dim: int, + group_name: str, +) -> torch.fx.Node: + mm_types = OrderedSet(map(type, matmuls)) + assert len(mm_types) == 1 + mm_type = next(iter(mm_types)) + if mm_type == _Matmul: + B_nodes = [ + graph_call_function_transpose(graph, matmul.A_node) for matmul in matmuls + ] + + res_fused_mm = graph.call_function( + torch.ops._ap_symm_mem.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + kwargs={"return_A": True}, + ) + + return res_fused_mm + else: + raise AssertionError(f"Unexpected matmul match type: {mm_type}") + + +def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None: + """ + Fused the pattern + + A = all_gather_tensor(A_shard, gather_dim, group_name) + C_0 = torch.matmul(A, B_0) + C_1 = torch.matmul(A, B_1) + C_2 = torch.matmul(A, B_2) + ... 
+ + into + + A, Cs = torch.ops._ap_symm_mem.fused_all_gather_matmul( + A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_shard_for_fused_all_gather_matmul, + ) + + shard_node, ag_node, ag_res_node, gather_dim, group_name = ( + all_gather.shard_node, + all_gather.ag_node, + all_gather.res_node, + all_gather.gather_dim, + all_gather.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + if ( + not _micro_pipeline_tp_ag_mm_last_dim_enabled + and gather_dim == _get_tensor(shard_node).ndim - 1 + ): + return + + # Find consumer matmuls + matmuls = _find_consumer_matmuls(ag_res_node) + + # The matmuls are only fusible if non-A args don't depend on the all-gather + # result node + matmuls = [ + matmul + for matmul in matmuls + if all_gather.res_node not in matmul.arg_ancestor_nodes + ] + + if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1: + return + + # Fuse the all_gather_tensor with the eligible matmuls + graph = ag_node.graph + with graph.inserting_before(ag_node): + if matmuls[0].pre_mm_B_transpose is not None: + if "val" in shard_node.meta: + restrided = restride_A_shard_for_fused_all_gather_matmul( + _get_tensor(shard_node), + gather_dim, + ) + shard_node = graph.call_function( + inductor_prims.force_stride_order, + args=(shard_node, restrided.stride()), + ) + fused_node = _insert_fused_all_gather_transpose_matmul( + graph, matmuls, shard_node, gather_dim, group_name + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + fused_out_1 = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + new_out_nodes = [ + graph.call_function(operator.getitem, args=(fused_out_1, i)) + for i in range(len(matmuls)) + ] + new_out_nodes = [ + graph_call_function_transpose(graph, out_node_t) + for out_node_t in new_out_nodes + ] + # Restride the inputs before fused that result will be contiguous (or pre pass stridenss) + new_out_nodes = [ + graph_call_function_contiguous(graph, out_node) + for out_node in new_out_nodes + ] + for matmul, new_out_node in zip(matmuls, new_out_nodes): + matmul.replace_with(new_out_node) + matmul.erase() + else: + if "val" in shard_node.meta: + restrided = restride_A_shard_for_fused_all_gather_matmul( + _get_tensor(shard_node), + gather_dim, + ) + shard_node = graph.call_function( + inductor_prims.force_stride_order, + args=(shard_node, restrided.stride()), + ) + fused_node = _insert_fused_all_gather_matmul( + graph, matmuls, shard_node, gather_dim, group_name + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + new_out_nodes_node = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + + for idx, matmul in enumerate(matmuls): + new_out_node = graph.call_function( + operator.getitem, + args=(new_out_nodes_node, idx), + ) + matmul.replace_with(new_out_node) + matmul.erase() + + all_gather.replace_with(new_ag_node) + all_gather.erase() + + # If the new_ag_node has no users, we tell the fused op to not return + # it. This creates more optimization opportunities. 
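+    # (With return_A=False, _fused_all_gather_matmul_impl skips copying the local
+    # shard into the all-gather output buffer and returns None for A; see the
+    # ag_out_needed handling in asynctp_ops.py.)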
+ if len(new_ag_node.users) == 0: + graph.erase_node(new_ag_node) + kwargs = dict(fused_node.kwargs) + if "return_A" in kwargs: + kwargs["return_A"] = False + fused_node.kwargs = kwargs + + # Raise ancestors of non-A args that are topologically ordered between + # ag_res_node and the matmul above fused_node. + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + OrderedSet(x for matmul in matmuls for x in matmul.arg_ancestor_nodes), + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + +def _scatter_dim_after_reshape( + reshape_node: torch.fx.Node, orig_scatter_dim: int +) -> int: + """ + Given a reshape node and the original scatter dim for the target tensor, + returns the new scatter dim for the reshaped tensor. + """ + # if there was no pre-mm reshape, scatter dim will not change. + if not reshape_node: + return orig_scatter_dim + + reshape_op_output_tensor = _get_tensor(reshape_node) + assert ( + reshape_op_output_tensor.ndim == 2 + ), "reshape must produce 2D tensor for scaled_mm" + + assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg" + input_tensor_node = cast(torch.fx.Node, reshape_node.args[0]) + reshape_op_input_tensor = _get_tensor(input_tensor_node) + assert ( + reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim + ), "reshape must be from 3D+ to 2D" + + # Note: for a N-D tensor to be reshaped into 2D, either the leading dims or ending dims must + # be collapsed to a single dim. First determine which of these happened. + input_shape = reshape_op_input_tensor.shape + output_shape = reshape_op_output_tensor.shape + leading_dims_collapsed = output_shape[0] == prod(input_shape[:-1]) + + # Case 1: scatter dim 0 always maps to 0 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == 0: + return 0 + + # Case 2: scatter dim "ndim-1" always maps to 1 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == reshape_op_input_tensor.ndim - 1: + return 1 + + # Case 3: scatter dim was one of the middle dims (between 0 and ndim-1). + # if the leading dims were collapsed, the new scatter dim will be 0. + # if the ending dims were collapsed, the new scatter dim will be 1. + return 0 if leading_dims_collapsed else 1 + + +def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]: + """ + Returns producer matmul node if found, otherwise returns None. 
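+
+    Recognized producer patterns (schematic):
+
+        mm(A, B)                                 -> _Matmul
+        scaled_mm(A, B, A_scale, B_scale, ...)   -> _ScaledMatmul
+        reshape -> (scaled_)mm -> reshape        -> _Matmul / _ScaledMatmul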
+ """ + if node.target == aten.mm.default: + return _Matmul.from_match(match=[node]) + elif node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match(match=[node]) + elif node.target == aten.reshape.default: + reshape_node_1 = node + + mm_node = reshape_node_1.args[0] + assert isinstance(mm_node, torch.fx.Node) + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + return None + + reshape_node_0 = mm_node.args[0] + assert isinstance(reshape_node_0, torch.fx.Node) + if reshape_node_0.target != aten.reshape.default: + return None + + if mm_node.target == aten.mm.default: + return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1]) + elif mm_node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match( + match=[reshape_node_0, mm_node, reshape_node_1] + ) + return None + + +def _insert_fused_matmul_reduce_scatter( + graph: torch.fx.Graph, + matmul: _Matmul, + reduce_op: str, + orig_scatter_dim: int, + group_name: str, + scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern + output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern +) -> torch.fx.Node: + if type(matmul) == _Matmul: # noqa: E721 + return graph.call_function( + torch.ops._ap_symm_mem.fused_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + reduce_op, + orig_scatter_dim, + group_name, + ), + ) + elif type(matmul) == _ScaledMatmul: # noqa: E721 + return graph.call_function( + torch.ops._ap_symm_mem.fused_scaled_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + matmul.A_scale_node, + matmul.B_scale_node, + reduce_op, + orig_scatter_dim, + scatter_dim_after_reshape, + group_name, + output_shape, + matmul.bias_node, + matmul.result_scale_node, + matmul.out_dtype, + matmul.use_fast_accum, + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {type(matmul)}") + + +def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: + """ + Fused the pattern + + reduce_scatter_tensor(A @ B, scatter_dim, group_name) + + into + + torch.ops._ap_symm_mem.fused_matmul_reduce_scatter( + A, B, scatter_dim, group_name, + ) + + Returns boolean indicating if fusion was successful or not. + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_for_fused_matmul_reduce_scatter, + ) + + ( + input_node, + _reduce_scatter_node, # noqa: F841 + rs_wait_tensor_node, + reduce_op, + orig_scatter_dim, + group_name, + ) = ( + reduce_scatter.input_node, + reduce_scatter.reduce_scatter_node, + reduce_scatter.wait_tensor_node, + reduce_scatter.reduce_op, + reduce_scatter.scatter_dim, + reduce_scatter.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + if ( + not _micro_pipeline_tp_mm_rs_last_dim_enabled + and orig_scatter_dim == _get_tensor(input_node).ndim - 1 + ): + return + + # Currently fused_matmul_reduce_scatter doesn't return the matmul result, + # so we can't apply the fusion if the matmul result is used by multiple + # users. This is not a fundamental limitation of the fused op and can be + # addressed if needed. + if len(input_node.users) != 1: + log.warning( + "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion." 
+ ) + return + + matmul = _find_producer_matmul(input_node) + if matmul is None: + log.warning( + "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion" + ) + return + + if rs_wait_tensor_node in matmul.arg_ancestor_nodes: + log.warning( + "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion" + ) + return + + # We need to track 3 values for the fused scaled mm reduce scatter implementation: + # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims. + # 2. The scatter dim after the reshape, to use when we are doing the 2D (a*b,c) @ (c,d) = (a,b,d) scaled mm op. + # 3. Store expected potentially 3D+ mm output shape, so we can reshape the 2D mm output to the intended + # 3D+ shape before applying reduce-scatter, and to prevent shape errors with subsequent ops. + + # If 'A' was reshaped from 3D+ -> 2D for the mm, we need to determine the new scattter dim after the reshape + # for the fused matmul reduce scatter implementation to use. + if matmul.pre_mm_reshape: + scatter_dim_after_maybe_reshape = _scatter_dim_after_reshape( + matmul.pre_mm_reshape, orig_scatter_dim + ) + else: + scatter_dim_after_maybe_reshape = orig_scatter_dim + + # If the 2D mm output was reshaped from 2D -> 3D+, we need to store the intended output shape for the + # fused matmul reduce scatter implementation to use. + if matmul.post_mm_reshape: + output_shape = list(_get_tensor(matmul.post_mm_reshape).shape) + else: + A_orig_shape = list(_get_tensor(matmul.A_node).shape) + B_shape = list(_get_tensor(matmul.B_node).shape) + output_shape = [*A_orig_shape[:-1], B_shape[-1]] + + graph = rs_wait_tensor_node.graph + with graph.inserting_before(rs_wait_tensor_node): + # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter + if "val" in matmul.A_node.meta: + restrided = restride_A_for_fused_matmul_reduce_scatter( + _get_tensor(matmul.A_node), + scatter_dim_after_maybe_reshape, + ) + matmul.A_node = graph.call_function( + inductor_prims.force_stride_order, + args=(matmul.A_node, restrided.stride()), + ) + + # Replace matched subgraph with fused matmul reduce scatter node + fused_node = _insert_fused_matmul_reduce_scatter( + graph, + matmul, + reduce_op, + orig_scatter_dim, + group_name, + scatter_dim_after_maybe_reshape, + output_shape, + ) + reduce_scatter.replace_with(fused_node) + reduce_scatter.erase() + matmul.erase() + + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + matmul.arg_ancestor_nodes, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + log.debug("successfully fused matmul reduce scatter") + + +def _get_node_to_ancestors( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: + """ + Compute the ancestors for all nodes in a graph. + """ + node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated] + for node in graph.nodes: + node_to_ancestors[node] = OrderedSet(node.all_input_nodes) + for dep in node.all_input_nodes: + node_to_ancestors[node] |= node_to_ancestors[dep] + + return node_to_ancestors + + +def _get_collective_to_overlappable_nodes( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, list[torch.fx.Node]]: + """ + For each collective in the graph, find nodes that are neither ancestors nor + descendants of the collective. 
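+
+    For example, if the graph contains `ag -> mm1` and an independent `mm2`,
+    then `mm2` is overlappable with `ag`, while `mm1` (a descendant of `ag`) is not.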
+ """ + + def is_collective(node) -> bool: + # Only consider all-gather and reduce-scatter in the context of + # micro-pipeline TP. + return node.target in [ + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + + node_to_ancestors = _get_node_to_ancestors(graph) + collective_to_overlappable_nodes = defaultdict(list) + for node in graph.nodes: + if not is_collective(node): + continue + for x in graph.nodes: + if ( + node not in node_to_ancestors[x] + and x not in node_to_ancestors[node] + and x.op == "call_function" + ): + collective_to_overlappable_nodes[node].append(x) + + return collective_to_overlappable_nodes + + +def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]: + """ + Find all unexposed collectives in the graph. + + Because we don't have the runtime estimate, this function is a rough + estimation using the following strong/hand-wavy assumptions: + + - Only a predefined set of "compute intensive" operation can hide a collective. + - Any "compute intensive" operation can hide exactly one collective. + """ + + def _is_compute_intensive(node: torch.fx.Node) -> bool: + return node.target in [torch.ops.aten.mm.default] + + collective_to_overlapping_candidates = defaultdict(list) + available_nodes = OrderedSet[torch.fx.Node]() + collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph) + for collective, overlappable_nodes in collective_to_overlappable_nodes.items(): + candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)] + collective_to_overlapping_candidates[collective] = candidates + available_nodes.update(candidates) + + unexposed_collectives = [] + for ( + collective, + overlapping_candidates, + ) in collective_to_overlapping_candidates.items(): + # Each collective consumes exactly one overlapping candidate + for x in overlapping_candidates: + if x in available_nodes: + unexposed_collectives.append(collective) + available_nodes.remove(x) + break + return unexposed_collectives + + +def micro_pipeline_tp_pass(graph: torch.fx.Graph): + all_gathers = find_all_gather_patterns(graph) + reduce_scatters = find_reduce_scatter_patterns(graph) + + # When a collective can be hidden through either simple overlapping or + # micro-pipeline TP, we prefer simple overlapping to avoid the overhead + # associated with decomposition. + unexposed_collectives = _get_unexposed_collectives(graph) + all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives] + reduce_scatters = [ + x for x in reduce_scatters if x.reduce_scatter_node not in unexposed_collectives + ] + + if not all_gathers and not reduce_scatters: + log.warning( + "async TP found no matching all-gather/reduce-scatter patterns for fusion" + ) + + for reduce_scatter in reduce_scatters: + fuse_matmul_reduce_scatter(reduce_scatter) + + for all_gather in all_gathers: + fuse_all_gather_matmul(all_gather) diff --git a/autoparallel/asynctp_ops.py b/autoparallel/asynctp_ops.py new file mode 100644 index 00000000..7d9a6ab6 --- /dev/null +++ b/autoparallel/asynctp_ops.py @@ -0,0 +1,1705 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+# Temporary fork of torch/distributed/_symmetric_memory/__init__.py +from __future__ import annotations + +import math +import os +from datetime import timedelta +from enum import Enum +from functools import partial +from typing import Any, Callable + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch._C._autograd import DeviceType +from torch.distributed._distributed_c10d import Work as _Work +from torch.distributed._distributed_c10d import _register_work, _SymmetricMemory +from torch.distributed._symmetric_memory import get_symm_mem_workspace, rendezvous + +_is_test_mode: bool = False +_mocked_group_names: set[str] | None = None +_backend_streams: dict[int, torch.cuda.Stream] = {} + + +def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream: + if priority not in _backend_streams: + _backend_streams[priority] = torch.cuda.Stream(priority=priority) + return _backend_streams[priority] + + +def _pipelined_multi_all_gather_and_consume( + shard: list[torch.Tensor], + shard_consumer: Callable[[list[torch.Tensor], int], None], + ag_out: list[torch.Tensor], + group_name: str, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + gathered = [ + all_gather_tensor(x, gather_dim=0, group=group) + for x in shard + ] + + shards = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + p2p_workspace_size_req = 0 + for x in shard: + p2p_workspace_size_req += x.numel() * x.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + for x, y in zip(shard, ag_out): + assert x.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `shard` must be contiguous" + ) + assert y.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `ag_out` must be contiguous" + ) + assert x.shape[0] * group_size == y.shape[0] + assert x.shape[1:] == y.shape[1:] + + def copy_shard(dst: list[torch.Tensor], src: list[torch.Tensor]) -> None: + for d, s in zip(dst, src): + d.copy_(s) + + def get_p2p_bufs(remote_rank: int) -> list[torch.Tensor]: + offset_bytes = 0 + bufs = [] + for x in shard: + buf = symm_mem.get_buffer( + remote_rank, + x.shape, + x.dtype, + storage_offset=offset_bytes // x.element_size(), + ) + bufs.append(buf) + offset_bytes += buf.numel() * buf.element_size() + return bufs + + local_p2p_bufs = get_p2p_bufs(rank) + + # shards[i] => shard from rank i + shards: list[list[torch.Tensor]] = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + # Parallelization strategy: after each rank copies its shard into its local + # p2p buffer, every rank issues independent p2p copy -> shard_consumer + # sequences to two streams. In addition to computation/communication + # overlapping, the strategy allows for computation/computation overlapping, + # greatly reducing quantization inefficiency. 
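+    # ("Quantization inefficiency" here refers to wave quantization: each
+    # decomposed sub-matmul may not fill all SMs on its own, so overlapping two
+    # independent sub-matmuls recovers the lost utilization.)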
+ # + # Notation: + # - "mv" for the copy to local buffer + # - "cp" for p2p copies + # - "b" for barriers + # + # Constraints: + # - The GPU scheduler may or may not overlap "mv" with the first shard_consumer. + # - "cp" from different streams cannot overlap. + # + # Ideal scenario 0 - "mv" overlaps with the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Ideal scenario 1 - "mv" is scheduled before the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ] [b][ cp ][ shard_consumer ] + # + # We haven't yet figured out a way to ensure "mv" and "b" are either + # overlapped with or scheduled before the first shard_consumer. Thus, to + # prevent suboptimal scenarios, we are giving up the chance to overlap "mv" + # and "b" with the first shard_consumer for now. + copy_shard(dst=local_p2p_bufs, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_bufs = get_p2p_bufs(remote_rank) + with stream: + copy_shard(dst=shards[remote_rank], src=remote_p2p_bufs) + shard_consumer(shards[remote_rank], remote_rank) + + if ag_out_needed: + # Copy from input to the all-gather output. Opportunistically overlap + # it with the last shard_consumer. 
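+        # (The last loop iteration, step == group_size - 1, ran on backend_stream
+        # when group_size is even and on the current stream when it is odd, so
+        # the copy below is issued to the opposite stream to overlap with it.)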
+ if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with stream: + copy_shard(dst=shards[rank], src=shard) + + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _pipelined_all_gather_and_consume( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: str, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + ag_out = all_gather_tensor(shard, gather_dim=0, group=group) + shards = ag_out.chunk(group.size()) + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + + def adapter(shard: list[torch.Tensor], rank: int) -> None: + shard_consumer(shard[0], rank) + + _pipelined_multi_all_gather_and_consume( + [shard], + adapter, + [ag_out], + group_name, + ag_out_needed, + ) + + +def _pipelined_produce_and_all2all( + chunk_producer: Callable[[int, torch.Tensor], None], + output: torch.Tensor, + group_name: str, + out_chunk_dim=0, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + chunks = [ + chunk_producer(dst_rank, chunks[dst_rank]) + for dst_rank in range(group_size): + ] + dist.all_to_all_single(output=output, input=torch.cat(chunks)) + """ + out_chunks = output.chunk( + c10d._get_group_size_by_name(group_name), dim=out_chunk_dim + ) + p2p_workspace_size_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: + assert idx in (0, 1) + offset = 0 if idx == 0 else out_chunks[0].numel() + return symm_mem.get_buffer( + rank, out_chunks[0].shape, out_chunks[0].dtype, offset + ) + + # Prepare two local p2p buffers, so that a remote rank can pull the result + # of step [i] in one p2p buffer while the local rank can compute the + # result of step [i+1] and write it directly the other p2p buffer. + local_p2p_buf_0 = get_p2p_buf(rank, 0) + local_p2p_buf_1 = get_p2p_buf(rank, 1) + + for step in range(1, group_size): + remote_rank = (rank - step) % group_size + if step % 2 == 0: + stream = torch.cuda.current_stream() + p2p_buf = local_p2p_buf_1 + remote_p2p_buf = get_p2p_buf(remote_rank, 1) + else: + stream = backend_stream + p2p_buf = local_p2p_buf_0 + remote_p2p_buf = get_p2p_buf(remote_rank, 0) + with stream: + # Parallelization strategy: every rank issues independent compute + # -> barrier -> p2p copy sequences on two streams. In addition to + # computation/communication overlapping, the strategy allows for + # computation/computation overlapping, greatly reducing + # quantization inefficiency. + # + # Ideally, stream activities would look like this ("b" for + # barriers, "cp" for p2p copies): + # + # [rank 0] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # Note that the barriers synchronize streams with the same ID + # across ranks. They don't synchronize streams on the same rank. 
+ # + # Since the work on both streams is independent, there's no + # guarantee that the chunk_producer from stream 0 or stream 1 will + # be scheduled first. If there is a scheduling mismatch across + # ranks, the barrier forces all ranks to wait for the slowest. + # + # When scheduling mismatches occur among ranks, the stream + # activities might look like this (note that p2p copies from + # different streams cannot overlap with each other): + # + # [rank 0] + # stream 0: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # stream 1: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # stream 1: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # + # To prevent this, we need to ensure that the chunk_producer on + # stream 1 gets scheduled first on every rank. Without access to + # the underlying kernels, CUDA offers no API to control the + # scheduling order of two independent, overlapping kernels. Our + # solution is to issue a small sleep kernel in stream 0. The sleep + # duration is insignificant, but having an extra task in stream 0 + # will almost guarantee that the chunk_producer on stream 1 gets + # scheduled first. Once the first chunk_producer is scheduled in + # the correct order, there's very little room for the scheduling + # order of subsequent kernels to be inconsistent across ranks. + if step == 2: + torch.cuda._sleep(100) + chunk_producer((rank + step) % group_size, p2p_buf) + symm_mem.barrier(channel=step % 2) + out_chunks[remote_rank].copy_(remote_p2p_buf) + # The local P2P buffer can only be overwritten by the next + # chunk_producer after all peers have finished reading from it. + symm_mem.barrier(channel=step % 2) + + # If the sleep wasn't issued in the above loop, do it now. + if group_size == 2: + torch.cuda._sleep(100) + + chunk_producer(rank, out_chunks[rank]) + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +lib = torch.library.Library("_ap_symm_mem", "DEF") # noqa: TOR901 +lib.define( + "fused_all_gather_matmul(" + "Tensor A, Tensor[] Bs, int gather_dim, str group_name, *, bool return_A = True) -> (Tensor?, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_all_gather_scaled_matmul(" + "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, " + "int gather_dim, str group_name, " + "Tensor?[] biases, " + "Tensor?[] result_scales, " + "ScalarType?[] out_dtypes, " + "bool[] use_fast_accum) -> (Tensor, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_scaled_matmul_reduce_scatter(" + "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, " + "str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, int[]? output_shape, " + "Tensor? bias = None, " + "Tensor? result_scale = None, " + "ScalarType? 
out_dtype = None, " + "bool use_fast_accum = False) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor") +lib.define( + "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor" +) + + +class _ScaleMode(Enum): + UNSCALED = "unscaled" + TENSOR_WISE = "tensor-wise" + ROW_WISE_SHARDED = "row-wise-sharded" + ROW_WISE_REPLICATED = "row-wise-replicated" + + +def _check_and_verify_fp8_all_gather_scale_mode( + shard: torch.Tensor, scale: torch.Tensor | None, gather_dim: int, group_size: int +) -> _ScaleMode: + full_shape = list(shard.shape) + full_shape[gather_dim] *= group_size + + if scale is None: + return _ScaleMode.UNSCALED + elif scale.shape[:-1] == shard.shape[:-1] and scale.shape[-1] == 1: + # Row-wise scaling + # + # NOTE: when the last dim of both A_shard and A_scale is one, we can't + # tell if A_scale is replicated tensor-wise scale or sharded row-wise + # scale. Treating it as row-wise scaling for safety. + return _ScaleMode.ROW_WISE_SHARDED + elif scale.numel() == 1: + return _ScaleMode.TENSOR_WISE + elif list(scale.shape[:-1]) == full_shape[:-1]: + return _ScaleMode.ROW_WISE_REPLICATED + else: + raise ValueError( + "Invalid scale shape for fp8 all-gather " + f"(shard shape: {shard.shape}, scale shape: {scale.shape})" + ) + + +def _fused_all_gather_matmul_impl( + mm_out_op: torch._ops.OpOverload, + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor | None, + kwargs_list: list[dict[str, Any]], + out_dtypes: list[torch.dtype | None], + gather_dim: int, + group_name: str, + return_A: bool, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + if A_shard.dim() < 2: + raise ValueError("A_shard must be a matrix") + for B in Bs: + if B.dim() != 2: + raise ValueError("B must be a matrix") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_types) must be the same as len(Bs)") + if len(kwargs_list) != len(Bs): + raise ValueError("len(kwargs_list) must be the same as len(Bs)") + if gather_dim < 0 or gather_dim >= A_shard.dim(): + raise ValueError("Invalid gather_dim") + + group = c10d._resolve_process_group(group_name) + + if gather_dim == A_shard.ndim - 1: + # Implementation for gathering on last dimension of matmul (N) + # A_shard splitted column wise + # A_shard: [A0, A1, ... , Ags] + return _fused_all_gather_matmul_last_gather_dim_impl( + mm_out_op, + A_shard, + Bs, + A_scale, + kwargs_list, + out_dtypes, + gather_dim, + group_name, + return_A, + ) + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix. + # The flattened tensor doesn't need to be contiguous (for computation + # efficiency), as _pipelined_all_gather_and_consume guarantees that shards + # passed to shard_consumer are contiguous. 
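+    # For illustration (hypothetical shapes): with A_shard (a, b, k), gather_dim=1
+    # and world size ws: movedim -> (b, a, k), flatten -> A_shard_flat (b*a, k);
+    # A_flat is (ws*b*a, k), and unflatten() maps an output of shape (ws*b*a, n)
+    # back to (a, ws*b, n), i.e. the matmul result gathered along dim 1.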
+ A_shard_flat = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(A_shard_flat.shape[:-1]) + A_shard_flat = A_shard_flat.flatten(0, -2) + + # Helper function for reverting the above transformation + def unflatten(t: torch.Tensor) -> torch.Tensor: + return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) + + A_flat = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + outputs = [ + A_flat.new_empty(A_flat.shape[0], B.shape[1], dtype=out_dtype or B.dtype) + for B, out_dtype in zip(Bs, out_dtypes) + ] + output_shards = [output.chunk(group.size()) for output in outputs] + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group.size() + ) + + # Computing block-wise matmul along the first dim of A + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + assert A_scale is not None + A_scale_shard = A_scale.movedim(gather_dim, 0).flatten(0, -2) + A_scale_flat = A_scale_shard.new_empty( + A_scale_shard.shape[0] * group.size(), + A_scale_shard.shape[1], + ) + + def row_wise_sharded_consumer(shard: list[torch.Tensor], rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard[0], + B, + scale_a=shard[1], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_multi_all_gather_and_consume( + [A_shard_flat, A_scale_shard], + row_wise_sharded_consumer, + [A_flat, A_scale_flat], + group_name, + return_A, + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + assert A_scale is not None + A_scale_shards = ( + A_scale.movedim(gather_dim, 0).flatten(0, -2).chunk(group.size()) + ) + + def row_wise_replicated_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard, + B, + scale_a=A_scale_shards[rank], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_all_gather_and_consume( + A_shard_flat, + row_wise_replicated_consumer, + A_flat, + group_name, + return_A, + ) + else: + if scale_mode == _ScaleMode.TENSOR_WISE: + assert A_scale is not None + for kwargs in kwargs_list: + kwargs["scale_a"] = A_scale + else: + assert scale_mode == _ScaleMode.UNSCALED + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + A_shard_flat, + default_consumer, + A_flat, + group_name, + return_A, + ) + + A = unflatten(A_flat) if return_A else None + return A, [unflatten(output) for output in outputs] + + +def _pipelined_all_gather_and_consume_last_dim( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: str, + ag_out_needed: bool = True, +) -> None: + p2p_workspace_size_req = 0 + p2p_workspace_size_req = shard.numel() * shard.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + def copy_shard(dst: torch.Tensor, src: torch.Tensor) -> None: + dst.copy_(src) + + def get_p2p_buf(remote_rank: int) -> torch.Tensor: + buf = symm_mem.get_buffer( + remote_rank, + shard.shape, + shard.dtype, + ) + return buf + + local_p2p_buf = get_p2p_buf(rank) + + # shards[i] => shard from rank i + shards = [] + for 
i, c in enumerate(ag_out.chunk(group_size)): + shards.append(c) + + copy_shard(dst=local_p2p_buf, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_buf = get_p2p_buf(remote_rank) + with stream: + copy_shard(dst=shards[remote_rank], src=remote_p2p_buf) + shard_consumer(shards[remote_rank], remote_rank) + + if ag_out_needed: + # Copy from input to the all-gather output. Opportunistically overlap + # it with the last shard_consumer. + if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with stream: + copy_shard(dst=shards[rank], src=shard) + + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _fused_all_gather_matmul_last_gather_dim_impl( + mm_out_op: torch._ops.OpOverload, + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor | None, + kwargs_list: list[dict[str, Any]], + out_dtypes: list[torch.dtype | None], + gather_dim: int, + group_name: str, + return_A: bool, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + assert gather_dim == A_shard.ndim - 1 + group = c10d._resolve_process_group(group_name) + + B_shards = [B.chunk(group.size()) for B in Bs] + + leading_dims = list(A_shard.shape[:-1]) + A_shard_flat = A_shard.flatten(0, -2) + + def unflatten(t: torch.Tensor) -> torch.Tensor: + return t.view(*leading_dims, -1) + + A_out_leading_dims = list(A_shard.shape[:-1]) + A_out_leading_dims[0] *= group.size() + + def unflatten_A_out(t: torch.Tensor) -> torch.Tensor: + return t.view(*A_out_leading_dims, -1) + + A_flat_out = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + # Outputs work as accumulator for output_partials + outputs = [ + torch.empty( + (A_shard_flat.shape[0], B.shape[1]), + dtype=out_dtype or B.dtype, + device=A_shard.device, + ) + for B, out_dtype in zip(Bs, out_dtypes) + ] + + # Additional allocation for partials output, + # That will be reduced into output. 
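+    # (Gathering A along its last (K) dim makes the full matmul a sum over
+    # K-blocks, sum_r A_r @ B_r, where A_r is the shard from rank r and B_r is
+    # the matching row-block of B. The first consumed shard writes straight into
+    # `outputs`; later shards write into `output_partials` and are accumulated
+    # with `+=`.)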
+ output_partials = [torch.empty_like(out) for out in outputs] + + first = True + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + nonlocal first + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + out = outputs[idx] if first else output_partials[idx] + mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out) + if not first: + outputs[idx] += output_partials[idx] + first = False + + _pipelined_all_gather_and_consume_last_dim( + A_shard_flat, + default_consumer, + A_flat_out, + group_name, + return_A, + ) + + A = unflatten_A_out(A_flat_out) if return_A else None + return A, [unflatten(output) for output in outputs] + + +@torch.library.impl(lib, "fused_all_gather_matmul", "Meta") +def _fused_all_gather_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: str, + *, + return_A: bool = True, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + if gather_dim == A.ndim - 1: + A_splits = A.chunk(group_size) + A_mm = torch.cat(A_splits, dim=-1) + res = [torch.matmul(A_mm, B) for B in Bs] + if return_A: + return A, res + else: + return None, res + + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs] + if return_A: + return A.movedim(0, gather_dim), res + else: + return None, res + + +@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") +def _fused_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: str, + *, + return_A: bool = True, +) -> tuple[torch.Tensor | None, list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + all_gather_tensor(A_shard, gather_dim, group_name) @ B + + Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + if _is_test_mode: + return _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim, group_name, return_A=return_A + ) + + if _should_use_fused_all_gather_matmul_native(A_shard, Bs, gather_dim, group_name): + group = c10d._resolve_process_group(group_name) + leading_dims = list(A_shard.shape[:-1]) + leading_dims[0] *= group.size() + A, out = _fused_all_gather_matmul_native( + A_shard.flatten(0, -2), Bs[0], group_name + ) + return A.view(*leading_dims, -1), [out.view(*leading_dims, -1)] + + if _should_use_multimem_all_gather_matmul( + A_shard, gather_dim, group_name, return_A + ): + return None, _multimem_all_gather_matmul(A_shard, Bs, group_name) + + with torch.profiler.record_function("fused_all_gather_matmul"): + return _fused_all_gather_matmul_impl( + torch.ops.aten.mm.out, + A_shard, + Bs, + None, + [{} for B in Bs], + [B.dtype for B in Bs], + gather_dim, + group_name, + return_A, + ) + + +def _should_use_fused_all_gather_matmul_native( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: str, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + + return ( + "TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP" in os.environ + and A_shard.is_contiguous() + and gather_dim == 0 + # _async_input_mm requires local_M to be divisible by world_size. 
+ and local_M % group.size() == 0 + # _async_input_mm outperforms the decomposition-based approach when the + # global M is small. + and 2048 < local_M * group.size() <= 4096 + # _async_input_mm only supports a single B. + and len(Bs) == 1 + ) + + +def _fused_all_gather_matmul_native( + A_shard: torch.Tensor, + B: torch.Tensor, + group_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + symm_mem = rendezvous(A_shard, group_name) + if symm_mem is None: + symm_mem = get_symm_mem_workspace( + group_name, A_shard.numel() * A_shard.element_size() + ) + symm_mem.barrier() + buf = symm_mem.get_buffer(symm_mem.rank, A_shard.shape, A_shard.dtype) + buf.copy_(A_shard) + A_shard = buf + + rank = symm_mem.rank + world_size = symm_mem.world_size + + current_stream = torch.cuda.current_stream() + backend_stream = _get_backend_stream(priority=-1) + + symm_mem.barrier() + backend_stream.wait_stream(current_stream) + current_stream.wait_stream(backend_stream) + + A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1]) + A_signals = torch.zeros(world_size, dtype=torch.uint32, device=A_shard.device) + A_shards = A.chunk(world_size) + + A_shards[rank].copy_(A_shard) + if not torch.cuda.is_current_stream_capturing(): + _SymmetricMemory.stream_write_value32(A_signals, rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=rank, val=1, count=1) + + out = torch.ops.symm_mem._async_input_mm(A, B, A_signals, rank) + for step in range(1, world_size): + src_rank = (rank + step) % world_size + src_buf = symm_mem.get_buffer(src_rank, A_shard.shape, A_shard.dtype) + with backend_stream: + A_shards[src_rank].copy_(src_buf) + if not torch.cuda.is_current_stream_capturing(): + # cuStreamWriteValue32 issues a system level fence before the write + _SymmetricMemory.stream_write_value32(A_signals, src_rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=src_rank, val=1, count=1) + + current_stream.wait_stream(backend_stream) + backend_stream.wait_stream(current_stream) + + symm_mem.barrier() + return A, out + + +def _should_use_multimem_all_gather_matmul( + A_shard: torch.Tensor, + gather_dim: int, + group_name: str, + return_A: bool, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + has_multicast_support = ( + A_shard.device.type == "cuda" + and _SymmetricMemory.has_multicast_support( + DeviceType.CUDA, A_shard.device.index + ) + ) + + return ( + has_multicast_support + and not return_A + and A_shard.is_contiguous() + and gather_dim == 0 + # The heuristic is empirical. We could refine it with a more + # sophisticated perf model. 
+ and local_M * group.size() <= 2048 + ) + + +def _multimem_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + group_name: str, +) -> list[torch.Tensor]: + group = c10d._resolve_process_group(group_name) + A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:])) + symm_mem = get_symm_mem_workspace( + group_name, A_shape.numel() * A_shard.element_size() + ) + A = symm_mem.get_buffer(symm_mem.rank, A_shape, A_shard.dtype) + torch.ops.symm_mem.multimem_all_gather_out(A_shard, group_name, A) + return [torch.matmul(A, B) for B in Bs] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta") +def _fused_all_gather_scaled_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: str, + biases: list[torch.Tensor | None], + result_scales: list[torch.Tensor | None], + out_dtypes: list[torch.dtype | None], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group_size + ) + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + A_scale_shard = A_scale + A_scale = torch.ops._c10d_functional.all_gather_into_tensor( + A_scale.contiguous(), group_size, group_name + ) + A_scale = torch.ops._c10d_functional.wait_tensor(A_scale) + A_scale = ( + A_scale.view(group_size, *A_scale_shard.shape) + .movedim(gather_dim + 1, 1) + .flatten(0, -2) + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + A_scale = A_scale.movedim(gather_dim, 0).flatten(0, -2) + else: + assert scale_mode == _ScaleMode.TENSOR_WISE + + def scaled_matmul( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + bias: torch.Tensor | None, + result_scale: torch.Tensor | None, + out_dtype: torch.dtype | None, + use_fast_accum: bool, + ) -> torch.Tensor: + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm( + A.flatten(0, -2), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + return res.unflatten(0, leading_dims) + + return A.movedim(0, gather_dim), [ + scaled_matmul( + A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum + ).movedim(0, gather_dim) + for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip( + Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA") +def _fused_all_gather_scaled_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: str, + biases: list[torch.Tensor | None], + result_scales: list[torch.Tensor | None], + out_dtypes: list[torch.dtype | None], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + A = all_gather_tensor(A_shard, gather_dim, group_name) + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, 
B_scale) + res = res.unflatten(0, leading_dims) + + The input `A_scale` can be tensor-wise, row-wise-sharded or + row-wise-replicated. + + Optimal stride order for `A_shard` - if `A_shard.movedim(gather_dim, 0)` is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + if len(biases) != len(Bs): + raise ValueError("len(biases) must be the same as len(Bs)") + if len(result_scales) != len(Bs): + raise ValueError("len(result_scales) must be the same as len(Bs)") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_dtypes) must be the same as len(Bs)") + if len(use_fast_accum) != len(Bs): + raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)") + + if _is_test_mode: + return _fused_all_gather_scaled_matmul_fallback( + A_shard, + Bs, + A_scale, + B_scales, + gather_dim, + group_name, + biases, + result_scales, + out_dtypes, + use_fast_accum, + ) + + with torch.profiler.record_function("fused_all_gather_scaled_matmul"): + A, res = _fused_all_gather_matmul_impl( + torch.ops.aten._scaled_mm.out, + A_shard, + Bs, + A_scale, + [ + { + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": fast_accum, + } + for B_scale, bias, result_scale, out_dtype, fast_accum in zip( + B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ], + out_dtypes, + gather_dim, + group_name, + True, + ) + assert A is not None + return A, res + + +def make_contiguous_for_perm( + t: torch.Tensor, + perm: list[int], +) -> torch.Tensor: + """ + Restride `t` such that `t.permute(perm)` is contiguous. + """ + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + return t.permute(perm).contiguous().permute(inv_perm) + + +def restride_A_shard_for_fused_all_gather_matmul( + t: torch.Tensor, + gather_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_all_gather_matmul` for optimal perf. + See the doc for `fused_all_gather_matmul` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(gather_dim)) + return make_contiguous_for_perm(t, perm) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") +def _fused_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + """ + Perform the following logic with micro-pipelined computation and + communication: + + reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + + Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no + extra copy is required for input layout transformation. Otherwise A needs + to be copied once. 
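+
+    Implementation note: the product is computed in per-rank chunks
+    (column blocks of `B` when `scatter_dim` is the last dim, row blocks of
+    `A` otherwise), the chunks are exchanged with an all-to-all, and the
+    received partials are reduced locally.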
+ """ + if _is_test_mode: + return _fused_matmul_reduce_scatter_fallback( + A, B, reduce_op, scatter_dim, group_name + ) + + with torch.profiler.record_function("fused_matmul_reduce_scatter"): + return _fused_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten.mm.out, + A=A, + B=B, + kwargs={}, + out_dtype=A.dtype, + reduce_op=reduce_op, + scatter_dim=scatter_dim, + group_name=group_name, + ) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") +def _fused_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + res = funcol.wait_tensor(res) + return res + + +def _fused_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: torch.dtype | None, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if scatter_dim < 0 or scatter_dim >= A.dim(): + raise ValueError("Invalid gather_dim") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + group = c10d._resolve_process_group(group_name) + out_shape = [*A.shape[:-1], B.shape[1]] + out_shape[scatter_dim] //= group.size() + + if scatter_dim == A.ndim - 1: + B_shards = B.chunk(group.size(), dim=B.ndim - 1) + A_flat = A.flatten(0, -2) + + def _chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_flat, B_shards[rank], **kwargs, out=out) + + leading_dims = list(A.shape[:-1]) + + stacked_partials = torch.empty( + (A_flat.shape[0], B.shape[1]), + dtype=out_dtype or A.dtype, + device=A.device, + ) + + _pipelined_produce_and_all2all( + _chunk_producer, + stacked_partials, + group_name, + out_chunk_dim=1, + ) + + stacked_partials_view = stacked_partials.reshape( + *leading_dims, group.size(), -1 + ) + return reduce_fn( + stacked_partials_view, + dim=-2, + ) + + # Move the scatter_dim to the front and flatten the tensor into a 2D matrix + x = A.movedim(scatter_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + leading_dims[1] //= group.size() + x = x.flatten(0, -2) + A_shards = x.chunk(group.size()) + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, **kwargs, out=out) + + stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype) + + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. 
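+    # stacked_partials is viewed as (group_size, scatter_shard, *rest, N); the
+    # two movedims put the shard dim back at scatter_dim + 1 and the group dim
+    # at scatter_dim, so reducing over dim=scatter_dim sums the per-rank
+    # partials and leaves the output in its original layout.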
+ return reduce_fn( + stacked_partials.view(*leading_dims, -1) + .movedim(1, scatter_dim + 1) + .movedim(0, scatter_dim), + dim=scatter_dim, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA") +def _fused_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if _is_test_mode: + return _fused_scaled_matmul_reduce_scatter_fallback( + A, + B, + A_scale, + B_scale, + reduce_op, + orig_scatter_dim, + scatter_dim_after_maybe_reshape, + group_name, + output_shape, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + with torch.profiler.record_function("fused_scaled_matmul_reduce_scatter"): + return _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten._scaled_mm.out, + A=A, + B=B, + A_scale=A_scale, + kwargs={ + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": use_fast_accum, + }, + out_dtype=out_dtype, + reduce_op=reduce_op, + orig_scatter_dim=orig_scatter_dim, + scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape, + group_name=group_name, + output_shape=output_shape, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "Meta") +def _fused_scaled_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: torch.Tensor | None = None, + result_scale: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + + C = torch._scaled_mm( + A.flatten(0, -2).contiguous(), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + C = C.view(*output_shape[:-1], B.shape[1]) + res = funcol.reduce_scatter_tensor( + C, + reduce_op, + orig_scatter_dim, # need original scatter dim for 3D+ output tensor here + group_name, + ) + res = funcol.wait_tensor(res) + return res + + +def _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: torch.dtype | None, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if ( + scatter_dim_after_maybe_reshape < 0 + or scatter_dim_after_maybe_reshape >= A.dim() + ): + raise ValueError("Invalid scatter dim for 2D tensor input to scaled_mm") + if orig_scatter_dim < 0 or orig_scatter_dim >= len(output_shape): + raise ValueError("Invalid scatter dim for 3D+ output tensor") + if B.dim() != 2: + 
raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + + # Move scatter to first dim, then shard the tensor along the first dim, so the chunk producer + # can perform matmuls along the first dim. + A_with_scatter_dim_0 = A.movedim(scatter_dim_after_maybe_reshape, 0) + + # To handle case where A is 3D+, reshape to 2D to prepare for mm which requires 2D inputs. + A_2D_with_scatter_dim_0 = A_with_scatter_dim_0.flatten(0, -2) + + # Partition A along the first dim to prepare for sharding across TP process group. + A_shards = A_2D_with_scatter_dim_0.chunk(group.size()) + + # Now that 'A' is sharded along the first dim, we need to update its scale(s) accordingly. + # How we do this depends on if we are using tensorwise scaling, rowwise scaling, or no scaling. + tensorwise_scaling = A_scale is not None and A_scale.numel() == 1 + rowwise_scaling = A_scale is not None and A_scale.numel() > 1 + + # For tensorwise scaling, the scale should be replicated so each shard has a copy. + if tensorwise_scaling: + A_scale_shards = [A_scale] * group.size() + + # For rowwise scaling, we need to move the scatter dim to the first dim to match the + # dim swap of the 'A' tensor. Then we can shard the scales along the first dim, just like + # the 'A' tensor. + elif rowwise_scaling: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = ( + A_scale.movedim(scatter_dim_after_maybe_reshape, 0) + .contiguous() + .flatten(0, -2) + ) + A_scale_shards = list(A_scale.chunk(group.size())) + # cuBLAS's row-wise kernel requires scales to be aligned to 16 bytes. + # When we slice them we might break this and need to reallocate them. + A_scale_shards = [ + t if t.data_ptr() % 16 == 0 else t.clone() for t in A_scale_shards + ] + else: + raise ValueError("A_scale cannot be none for scaled_mm") + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out) + + # Stacked partials will be the 2D outputs of the the pipelined scaled mm, and will + # have the shape (A_with_scatter_dim_0_tensor.shape[0], B.shape[1]) to align with the formula: + # (a*b,c) @ (c,d) = (a*b,d) + stacked_partials = A_with_scatter_dim_0.new_empty( + A_2D_with_scatter_dim_0.shape[0], B.shape[1], dtype=out_dtype or A.dtype + ) + + # Execute the pipelined mm/scaled_mm. + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # We now need to transform the *unreduced* stacked 2D partial mm outputs to an *unreduced* 3D+ output, + # then reduce-scatter. To do this, we first need to determine the shape of the unreduced 3D+ output, + # to reshape our stacked partials so we can apply the reduce-scatter. + # + # The *unreduced* 3D+ tensor will have dim 0 = `group_size`, as we have `group_size` instances of + # stacked partial outputs. The next dims will be A's leading dims (sharded along the original scatter dim), + # as it was the left operand of the mm op. We can use -1 as the final dim of the view to populate the rest. 
+ stacked_partials_3D_leading_dims = [group.size()] + list( + # We use A from after the dim swap 0<=>scatter_dim, but before the flatten, + # to get the leading dims of the 3D+ view of stacked partials. + A_with_scatter_dim_0.shape[:-1] + ) + + # The `group_size` leading dim has been prepended to `stacked_partials_3D_leading_dims`, + # to capture the partial output from each rank. We need to divide the sharding/scatter dim + # by the group size. If the original scatter dim was 0, then it is now dim 1 in this + # tensor, since this new `group_size` dim was prepended. + stacked_partial_scatter_dim = orig_scatter_dim if orig_scatter_dim > 0 else 1 + stacked_partials_3D_leading_dims[stacked_partial_scatter_dim] //= group.size() + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + reduced_out = reduce_fn( + # View 2D stacked partials as 3D+ tensor of shape (`group_size`, ...) + stacked_partials.view(*stacked_partials_3D_leading_dims, -1) + # We originally swapped 0<=>scatter_dim_after_maybe_reshape. Now after + # prepending the `group_size` dim, to undo this original swap, we + # must swap 1<=>scatter_dim_after_maybe_reshape+1. + .movedim(1, scatter_dim_after_maybe_reshape + 1), + # Reduce along the `group_size` dim (0). + dim=0, + ) + + # Output shape must be scattered along original scatter dim as well. + output_shape[orig_scatter_dim] //= group.size() + out = reduced_out.view(*output_shape) + return out + + +def restride_A_for_fused_matmul_reduce_scatter( + t: torch.Tensor, + scatter_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_matmul_reduce_scatter` for optimal + perf. See the doc for `fused_matmul_reduce_scatter` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(scatter_dim)) + return make_contiguous_for_perm(t, perm) + + +def _maybe_convert_scalar_types_to_dtypes( + scalar_types: list[Any], +) -> list[torch.dtype | None]: + """ + When a list of `torch.dtype`s is passed through the dispatcher as + `ScalarType[]`, it is converted to a list of scalar type enum values. This + function converts it back to a list of `torch.dtype`s. 
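+
+    Example (illustrative): [5, None, 15] -> [torch.half, None, torch.bfloat16];
+    inputs that already contain `torch.dtype`s are returned unchanged.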
+ """ + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + _SCALAR_TYPE_TO_DTYPE = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.short, + 3: torch.int, + 4: torch.int64, + 5: torch.half, + 6: torch.float, + 7: torch.double, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex128, + 11: torch.bool, + 12: torch.qint8, + 13: torch.quint8, + 14: torch.qint32, + 15: torch.bfloat16, + 16: torch.float8_e5m2, + 17: torch.float8_e4m3fn, + 18: torch.float8_e5m2fnuz, + 19: torch.float8_e4m3fnuz, + } + if any(not isinstance(x, (type(None), int)) for x in scalar_types): + return scalar_types + + dtypes: list[torch.dtype | None] = [] + for scalar_type in scalar_types: + if scalar_type is None: + dtypes.append(scalar_type) + elif scalar_type not in _SCALAR_TYPE_TO_DTYPE: + raise ValueError("Unrecognized scalar type {scalar_type}") + else: + dtypes.append(_SCALAR_TYPE_TO_DTYPE[scalar_type]) + return dtypes + + +class Work(_Work): + def __init__(self) -> None: + super().__init__() + self.event = torch.cuda.Event() + self.event.record() + + def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool: + self.event.wait() + return True + + +""" +NOTE [low-contention collectives] +When a collective is overlapped with abundant compute, it makes sense to +prioritize reducing the contention between the collective and the overlapped +compute, even at the cost of a slightly slower collective. + +Common collective implementations (e.g., NCCL without user buffer +registration) optimize for throughput with no ambient compute. However, such +implementations may not be optimal when they are overlapped with compute: +- These implementations typically fuse the entire collective into a single +kernel and reserve SM resources based on the most demanding portion of the +collective, even when a large portion of the collective does not require this +much resource. +- These implementations often use SM-based P2P copy as opposed to copy +engine-based P2P copy. Copy engine-based P2P copy may not have a significant +advantage when there's no ambient compute. However, it may significantly +improve overall resource utilization in the presence of ambient compute. + +When overlapped with intensive compute (e.g., persistent matmul kernels), the +SM-usage of a collective can lead to inefficient overlapping. + +Low-contention collectives achieve their goals with the following strategies: +- Use copy engine-based copy whenever possible. +- Break down portions of a collective with different resource requirements +into multiple kernels. This improves the overlapping efficiency at the cost +of additional launching overhead. +""" + + +@torch.library.impl(lib, "_low_contention_all_gather", "Meta") +def _low_contention_all_gather_meta( + tensor: torch.Tensor, + group_name: str, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:]) + + +@torch.library.impl(lib, "_low_contention_all_gather", "CUDA") +def _low_contention_all_gather( + tensor: torch.Tensor, + group_name: str, +) -> torch.Tensor: + """ + Performs all-gather with symmetric memory in a low-contention fashion. + + When `tensor` is already in symmetric memory: + - The collective is carried out without using SMs. + - No symmetric memory workspace is required. 
+ + When `tensor` is not in symmetric memory: + - An extra SM-based copy is performed to copy the input data into the + symmetric memory workspace. + - Symmetric memory workspace size requirement: the size of `tensor`. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + input_is_symm_mem = True + else: + symm_mem = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + input_is_symm_mem = False + + rank = symm_mem.rank + world_size = symm_mem.world_size + + output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:]) + chunks = output.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + if not input_is_symm_mem: + local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype) + local_buf.copy_(tensor) + # pull + symm_mem.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + _register_work(output, Work()) + return output + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta") +def _low_contention_reduce_scatter_meta( + tensor: torch.Tensor, + reduce_op: str, + group_name: str, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.unflatten(0, (group_size, -1)).mean(dim=0) + + +def _low_contention_reduce_scatter_with_symm_mem_input( + tensor: torch.Tensor, + reduce_op: str, + symm_mem: _SymmetricMemory, +) -> torch.Tensor: + rank = symm_mem.rank + world_size = symm_mem.world_size + + assert tensor.shape[0] % world_size == 0 + a2a_res = torch.empty_like(tensor) + chunks = a2a_res.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # pull + offline reduction + symm_mem.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer( + remote_rank, + chunks[0].shape, + chunks[0].dtype, + chunks[0].numel() * rank, + ) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + + ret = a2a_res.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + _register_work(ret, Work()) + return ret + + +def _low_contention_reduce_scatter_with_workspace( + tensor: torch.Tensor, + reduce_op: str, + workspace: _SymmetricMemory, +) -> torch.Tensor: + rank = workspace.rank + world_size = workspace.world_size + + assert tensor.shape[0] % world_size == 0 + chunks = tensor.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # push + offline reduction + workspace.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + dst_buf = workspace.get_buffer( + remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank + ) + dst_buf.copy_(chunks[remote_rank]) + workspace.barrier() + + buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype) + ret = buf.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + _register_work(ret, Work()) + return ret + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA") +def 
_low_contention_reduce_scatter( + tensor: torch.Tensor, + reduce_op: str, + group_name: str, +) -> torch.Tensor: + """ + Performs reduce-scatter with symmetric memory in a low-contention fashion. + + This implementation performs a P2P-based all-to-all followed by an offline + reduction. + + When `tensor` is already in symmetric memory: + - Pull-based all-to-all is used. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - Push-based all-to-all is used. + - Symmetric memory workspace size requirement: the size of `tensor`. + + SM-usage: + - SM-based copy of the rank's own chunk for the all-to-all. + - Reduction on the all-to-all result. + + TODO(yifu): the SM-based copy can be avoided with a list-based reduction + kernel. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + return _low_contention_reduce_scatter_with_symm_mem_input( + tensor, reduce_op, symm_mem + ) + else: + workspace = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + return _low_contention_reduce_scatter_with_workspace( + tensor, reduce_op, workspace + ) diff --git a/examples/example_llama3.py b/examples/example_llama3.py index df57751f..cfeefe5c 100644 --- a/examples/example_llama3.py +++ b/examples/example_llama3.py @@ -43,6 +43,7 @@ device = torch.device("cuda") model_type = "8b" +enable_asynctp = False def model_fn(): @@ -191,6 +192,25 @@ def add_tp_constraints(autop): if enable_manual_constraint and not use_1d_mesh: add_tp_constraints(autop) + if enable_asynctp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + enable_symm_mem_for_group(mesh["dp"].get_group().group_name) + enable_symm_mem_for_group(mesh["tp"].get_group().group_name) + torch._inductor.config._micro_pipeline_tp = False + from autoparallel.asynctp import micro_pipeline_tp_pass + + existing_post_grad_custom_post_pass = ( + torch._inductor.config.post_grad_custom_post_pass + ) + + def _pass(graph): + if existing_post_grad_custom_post_pass is not None: + existing_post_grad_custom_post_pass(graph) + micro_pipeline_tp_pass(graph) + + torch._inductor.config.post_grad_custom_post_pass = _pass + t = time.time() sharding_placement = autop.optimize_placement() print(f"Took {time.time() - t:.2f} s")