From 1b20da9a5e0c67d93a8b8e9a05ca94cc304e9af3 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Mon, 29 Sep 2025 11:23:51 +0000
Subject: [PATCH] Add scatter propagation rule

PyTorch's built-in rule for scatter only supports Replicate placements.
Add sharding rules as well, adapted from the scatter_add strategy in
PyTorch.
---
 autoparallel/propagation_rules.py | 41 ++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index 68cb42a9..89118f9e 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -783,3 +783,44 @@ def einsum_rule(mesh, op_schema):
         kwargs_schema={},
     )
     return _mm_like_strategy(mm_equation, mesh, new_op_schema)
+
+
+@register_opschema_rule(torch.ops.aten.scatter.src)
+def scatter_strategy(mesh, op_schema: OpSchema):
+    # adapted from the scatter_add strategy in PyTorch
+    from torch.distributed.tensor._ops._tensor_ops import (
+        PlacementList,
+        expand_to_full_mesh_op_strategy,
+        normalize_dim,
+    )
+
+    input_strategy = op_schema.args_schema[0]
+    dim = op_schema.args_schema[1]
+    index_strategy = op_schema.args_schema[2]
+
+    assert isinstance(input_strategy, OpStrategy)
+    assert isinstance(index_strategy, OpStrategy)
+    assert isinstance(dim, int)
+    dim = normalize_dim(dim, input_strategy.ndim)
+    mesh = input_strategy.mesh
+    input_shape = input_strategy.shape
+    index_shape = index_strategy.shape
+
+    single_mesh_dim_strategies = []
+
+    # placement list stores placements of [output, input, index, src]
+    # first we always have replicate all for inputs and output
+    all_replicate: PlacementList = [Replicate()] * 4
+    single_mesh_dim_strategies.append(all_replicate)
+
+    # when input and index have the same rank, every dim d other than the
+    # scatter dim whose sizes match can be sharded identically on all four
+    if len(input_shape) == len(index_shape):
+        for d in range(len(input_shape)):
+            if d != dim and input_shape[d] == index_shape[d]:
+                sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)]
+                single_mesh_dim_strategies.append(sharding)
+
+    return expand_to_full_mesh_op_strategy(
+        mesh, op_schema, single_mesh_dim_strategies, input_index=1
+    )
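
Note for reviewers (illustration only, not part of the commit): the core of
the rule is the per-mesh-dim enumeration of valid placements. Below is a
minimal standalone sketch of that enumeration, using hypothetical stub
Replicate/Shard classes in place of the real DTensor placements so it runs
without any distributed setup:

    class Replicate:
        def __repr__(self):
            return "R"

    class Shard:
        def __init__(self, dim):
            self.dim = dim

        def __repr__(self):
            return f"S({self.dim})"

    def scatter_single_mesh_dim_strategies(input_shape, index_shape, dim):
        # placement list order: [output, input, index, src]
        strategies = [[Replicate()] * 4]
        if len(input_shape) == len(index_shape):
            for d in range(len(input_shape)):
                if d != dim and input_shape[d] == index_shape[d]:
                    strategies.append([Shard(d)] * 4)
        return strategies

    # scatter along dim 0 of an (8, 16) input with a rank-2 index:
    # only dim 1 qualifies, so we get all-Replicate plus all-Shard(1)
    print(scatter_single_mesh_dim_strategies((8, 16), (4, 16), dim=0))
    # [[R, R, R, R], [S(1), S(1), S(1), S(1)]]

The scatter dim itself is excluded because scatter writes to data-dependent
positions along that dim, so sharding it would require cross-shard
communication.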
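
These per-mesh-dim choices are then combined by
expand_to_full_mesh_op_strategy. My understanding (an assumption about that
helper, not a description of its implementation) is that each mesh dimension
picks one single-mesh-dim strategy independently, roughly a cartesian
product:

    from itertools import product

    single_mesh_dim_strategies = [
        ("R", "R", "R", "R"),              # all replicate
        ("S(1)", "S(1)", "S(1)", "S(1)"),  # all shard on tensor dim 1
    ]

    mesh_ndim = 2  # e.g. a 2-D device mesh
    for combo in product(single_mesh_dim_strategies, repeat=mesh_ndim):
        # combo[k] is the [output, input, index, src] placement on mesh dim k
        print(combo)
    # 2**2 = 4 candidate strategies for this example

input_index=1 tells the helper that each placement list starts with a single
output followed by the inputs; the helper is also expected to drop
combinations that are not shardable for the actual tensor shapes.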