39 changes: 39 additions & 0 deletions autoparallel/propagation_rules.py
@@ -783,3 +783,42 @@ def einsum_rule(mesh, op_schema):
        kwargs_schema={},
    )
    return _mm_like_strategy(mm_equation, mesh, new_op_schema)


@register_opschema_rule(torch.ops.aten.scatter.src)
def scatter_strategy(mesh, op_schema: OpSchema):
    # adapted from the scatter_add sharding strategy in PyTorch
    from torch.distributed.tensor._ops._tensor_ops import (
        PlacementList,
        expand_to_full_mesh_op_strategy,
        normalize_dim,
    )

    input_strategy = op_schema.args_schema[0]
    dim = op_schema.args_schema[1]
    index_strategy = op_schema.args_schema[2]

    assert isinstance(input_strategy, OpStrategy)
    assert isinstance(index_strategy, OpStrategy)
    assert isinstance(dim, int)
    dim = normalize_dim(dim, input_strategy.ndim)
    mesh = input_strategy.mesh
    input_shape = input_strategy.shape
    index_shape = index_strategy.shape

    single_mesh_dim_strategies = []

    # each placement list stores the placements of [output, input, index, src]
    # the all-replicate strategy is always valid, so add it first
    all_replicate: PlacementList = [Replicate()] * 4
    single_mesh_dim_strategies.append(all_replicate)

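    # sharded strategies are only proposed when index has the same rank as
    # input: sharding along the scatter dim itself is never proposed, since a
    # write along `dim` may land in a different shard, but any other dim can
    # be sharded when input and index agree on its size, keeping the Shard(d)
    # placements aligned across output, input, index and src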
    if len(input_shape) == len(index_shape):
        for d in range(len(input_shape)):
            if d != dim and input_shape[d] == index_shape[d]:
                sharding: PlacementList = [Shard(d), Shard(d), Shard(d), Shard(d)]
                single_mesh_dim_strategies.append(sharding)

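    # input_index=1: the first entry of each placement list above is the
    # single output; the remaining entries map to the input, index and src
    # arguments of the op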
    return expand_to_full_mesh_op_strategy(
        mesh, op_schema, single_mesh_dim_strategies, input_index=1
    )
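
For reference, a minimal sketch (not part of the diff) of the aten.scatter.src case this strategy targets, using plain torch.scatter to illustrate why only non-scatter dims where input and index sizes match get a Shard(d) placement:

import torch

# scatter along dim=0; index has the same rank as x and matches it on dim 1
x = torch.zeros(4, 6)
index = torch.tensor([[0, 1, 2, 0, 1, 2]])  # shape (1, 6)
src = torch.ones(1, 6)

out = x.scatter(0, index, src)

# out[index[i, j], j] = src[i, j]: a write picks its row from `index` but
# always stays in column j, so sharding all four tensors column-wise keeps
# every write local to its shard, whereas sharding rows (the scatter dim)
# would not.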