From 315f44bafb2ad215a3fcc6cdcf4ef17d5e8739a0 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Wed, 6 Aug 2025 15:34:54 +0000
Subject: [PATCH] Add gather and scatter_add strategies

They were taken from https://github.com/pytorch-labs/autoparallel/pull/29
---
 autoparallel/propagation_rules.py | 99 +++++++++++++++++++++++++++++++
 1 file changed, 99 insertions(+)

diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index a3e309ff..bf34c40c 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -660,6 +660,105 @@ def index_rule(mesh, op_schema):
     return out_strat
 
 
+@register_opschema_rule(torch.ops.aten.gather.default)
+def gather_strategy(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import PlacementList
+    from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 3
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # input sharding: input sharded, index accepts mask partial, output follows index
+    # this only works when the input is sharded on the gather dimension, and
+    # index has size 1 on the gather dimension
+    if index_shape[dim] == 1:
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
+        input_sharding: PlacementList = [
+            index_partial_placement,
+            Shard(dim),
+            index_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(input_sharding)
+
+    # index sharding: input replicated, index sharded, output follows index
+    # this only works when the sharding dimension is the gather dimension
+    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)]
+    single_mesh_dim_strategies.append(index_sharding)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim:
+                sharding = [Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
+@register_opschema_rule(torch.ops.aten.scatter_add.default)
+def scatter_add_strategy(mesh, op_schema):
+    from torch.distributed.tensor._op_schema import PlacementList
+
+    # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
+    from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+    # src_strategy = op_schema.args_schema[3]
+
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index, src]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    """
+    # input sharding: input sharded, index accepts mask partial, output follows index
+    # this only works when the input is sharded on the gather dimension, and
+    # index has size 1 on the gather dimension
+    if index_shape[dim] == 1:
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
+        input_sharding: PlacementList = [
+            index_partial_placement,
+            Shard(dim),
+            index_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(input_sharding)
+    """
+    # index sharding: input replicated, index and src sharded, output follows index
+    # this only works when the sharding dimension is the scatter dimension
+    index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)]
+    single_mesh_dim_strategies.append(index_sharding)
+
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim:
+                sharding = [Shard(d), Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
+
+
 def sdpa_rule(op, mesh, op_schema):
     out_strat = get_op_strategy(op, op_schema)
     # remove wrong context-parallel strategy
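Note (not part of the patch): below is a minimal single-process sketch of the two sharding rules the strategies above encode, using plain tensor chunks to stand in for shards. It uses only stock torch ops, no autoparallel or DTensor APIs, and the shapes, dims, and chunk counts are illustrative assumptions.

import torch

# Gather rule [Shard(dim), Replicate(), Shard(dim)]: with the input replicated and the
# index sharded along the gather dim, gathering per index shard and concatenating along
# that dim reproduces the full result.
dim = 0
x = torch.arange(16.0).reshape(4, 4)
idx = torch.randint(0, 4, (4, 4))
full = torch.gather(x, dim, idx)
per_shard = [torch.gather(x, dim, chunk) for chunk in idx.chunk(2, dim)]
assert torch.equal(torch.cat(per_shard, dim), full)

# scatter_add rule [Shard(d), Shard(d), Shard(d), Shard(d)] for d != dim: sharding input,
# index, and src along a non-scatter dim keeps every shard's updates local to that shard.
d = 1  # d != dim
src = torch.arange(16.0).reshape(4, 4)
full = torch.zeros(4, 4).scatter_add(dim, idx, src)
per_shard = [
    torch.zeros(4, 2).scatter_add(dim, i, s)  # each shard is 2 columns wide here
    for i, s in zip(idx.chunk(2, d), src.chunk(2, d))
]
assert torch.equal(torch.cat(per_shard, d), full)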