From f42b103ec870dd521c2942e4d8c8885e643242ea Mon Sep 17 00:00:00 2001
From: mori360
Date: Tue, 4 Nov 2025 10:06:06 -0800
Subject: [PATCH 1/3] redistribute cost

---
 autoparallel/collective_runtime_estimation.py | 60 +++++++++++++------
 1 file changed, 42 insertions(+), 18 deletions(-)

diff --git a/autoparallel/collective_runtime_estimation.py b/autoparallel/collective_runtime_estimation.py
index cbffcd1..39fbf65 100644
--- a/autoparallel/collective_runtime_estimation.py
+++ b/autoparallel/collective_runtime_estimation.py
@@ -8,9 +8,9 @@
 import torch.distributed.tensor._dtensor_spec as dtensor_spec
 from torch._prims_common import check_contiguous_sizes_strides
 from torch.distributed.tensor._collective_utils import (
-    MeshTopoInfo,
     allgather_cost,
     allreduce_cost,
+    MeshTopoInfo,
     reduce_scatter_cost,
     spec_to_bytes,
 )
@@ -63,29 +63,53 @@ def redistribute_cost(

     mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
     cost = 0.0
-    comm_bytes_gb = (
-        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
+    import math
+
+    from torch.distributed._functional_collectives import _are_we_tracing
+    from torch.distributed.tensor._redistribute import (
+        _gen_transform_infos,
+        _gen_transform_infos_non_cached,
     )
-    # Transformation that considered for redistribute cost:
-    # 1. allgather 2. alltoall
-    # 3. allreduce 4. reduce_scatter
-    curr_placements = [current_spec.placements[i] for i in order]
-    tgt_placements = [target_spec.placements[i] for i in order]
+
+    if _are_we_tracing():
+        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
+    else:
+        transform_infos = _gen_transform_infos(current_spec, target_spec)
+
+    # Transformation that considered for redistribute cost:
+    # 1. allgather 2. alltoall
+    # 3. allreduce 4. reduce_scatter
+    # curr_placements = [current_spec.placements[i] for i in order]
+    # tgt_placements = [target_spec.placements[i] for i in order]
+    # is_contiguous: bool = check_contiguous_sizes_strides(
+    #     current_spec.shape, current_spec.stride
+    # )
+    # for i, current, target in zip(order, curr_placements, tgt_placements):
+    #     if not is_contiguous:
+    #         cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
     is_contiguous: bool = check_contiguous_sizes_strides(
         current_spec.shape, current_spec.stride
     )
-    for i, current, target in zip(order, curr_placements, tgt_placements):
-        if current == target:
-            continue
-        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
+    for transform_info in transform_infos:
+        comm_bytes_gb = (
+            current_spec.tensor_meta.dtype.itemsize
+            * math.prod(transform_info.logical_shape)
+            / 1024
+            / 1024
+            / 1024
+        )
         if not is_contiguous:
             cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
+        current = transform_info.src_dst_placements[0]
+        target = transform_info.src_dst_placements[1]
+        if current == target:
+            continue
+        mesh_dim = transform_info.mesh_dim
+        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
         if current.is_shard() and target.is_replicate():
             current = cast(Shard, current)
-            # allgather gives larger comm bytes
-            comm_bytes_gb *= num_devices_on_mesh_dim
             # add up allgather comm cost
-            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
+            cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim)
             if current.dim != 0:
                 # penalize cases like S(1) -> R as there are additional compute cost
                 # which corresponds to reshuffling the whole output tensor
@@ -98,7 +122,7 @@ def redistribute_cost(
             target = cast(Shard, target)
             # should be alltoall comm, since we haven't implement it yet, add penalty
             # to favor allgather instead
-            cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)  # us
+            cost += all_to_all_cost(comm_bytes_gb, mesh_topo, mesh_dim)  # us

             num_copies = 0
             if current.dim != 0:
@@ -112,11 +136,11 @@ def redistribute_cost(

         elif current.is_partial() and target.is_replicate():
             # add up allreduce comm cost
-            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
+            cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim)
         elif current.is_partial() and target.is_shard():
             target = cast(Shard, target)
             # add up reduce_scatter comm cost
-            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
+            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim)
             if target.dim != 0:
                 # penalize cases like P -> S(1) as there are additional compute cost
                 # which corresponds to reshuffling the whole input tensor
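
Note on PATCH 1/3: instead of computing a single comm_bytes_gb up front from
spec_to_bytes(current_spec), the loop now walks the transform steps returned by
_gen_transform_infos and recomputes the communication volume of each step from
the dtype itemsize and that step's logical_shape, then charges allgather_cost /
all_to_all_cost / allreduce_cost / reduce_scatter_cost on transform_info.mesh_dim
rather than the placement index i. Presumably that is also why the explicit
"comm_bytes_gb *= num_devices_on_mesh_dim" scaling for allgather is dropped.
A minimal, self-contained sketch of the per-step accounting, using a hypothetical
TransformStep stand-in rather than the real _TransformInfo objects:

    import math
    from dataclasses import dataclass

    @dataclass
    class TransformStep:
        # hypothetical stand-in for torch.distributed.tensor._redistribute._TransformInfo
        mesh_dim: int
        src_dst_placements: tuple  # e.g. ("S(0)", "R")
        logical_shape: tuple       # tensor shape the collective operates on at this step

    def step_comm_gb(itemsize: int, step: TransformStep) -> float:
        # mirrors the patch: dtype itemsize * prod(logical_shape), converted to GiB
        return itemsize * math.prod(step.logical_shape) / 1024 / 1024 / 1024

    # example: one allgather step over an 8192x8192 fp32 tensor on mesh dim 0
    step = TransformStep(0, ("S(0)", "R"), (8192, 8192))
    print(f"{step_comm_gb(4, step):.2f} GiB")  # 0.25 GiB
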
From 673ae1e2e8e0906f065bbb9ed4582fa227651f33 Mon Sep 17 00:00:00 2001
From: mori360
Date: Tue, 4 Nov 2025 14:16:29 -0800
Subject: [PATCH 2/3] add is_contiguous

---
 autoparallel/collective_runtime_estimation.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/autoparallel/collective_runtime_estimation.py b/autoparallel/collective_runtime_estimation.py
index 39fbf65..7818370 100644
--- a/autoparallel/collective_runtime_estimation.py
+++ b/autoparallel/collective_runtime_estimation.py
@@ -81,12 +81,6 @@ def redistribute_cost(
     # 3. allreduce 4. reduce_scatter
     # curr_placements = [current_spec.placements[i] for i in order]
     # tgt_placements = [target_spec.placements[i] for i in order]
-    # is_contiguous: bool = check_contiguous_sizes_strides(
-    #     current_spec.shape, current_spec.stride
-    # )
-    # for i, current, target in zip(order, curr_placements, tgt_placements):
-    #     if not is_contiguous:
-    #         cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
     is_contiguous: bool = check_contiguous_sizes_strides(
         current_spec.shape, current_spec.stride
     )
From 768891e445d9ea1c2082c5b73c91937927b8b8d3 Mon Sep 17 00:00:00 2001
From: mori360
Date: Tue, 4 Nov 2025 14:46:56 -0800
Subject: [PATCH 3/3] lint

---
 autoparallel/collective_runtime_estimation.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/autoparallel/collective_runtime_estimation.py b/autoparallel/collective_runtime_estimation.py
index 7818370..ec3f01e 100644
--- a/autoparallel/collective_runtime_estimation.py
+++ b/autoparallel/collective_runtime_estimation.py
@@ -8,11 +8,10 @@
 import torch.distributed.tensor._dtensor_spec as dtensor_spec
 from torch._prims_common import check_contiguous_sizes_strides
 from torch.distributed.tensor._collective_utils import (
+    MeshTopoInfo,
     allgather_cost,
     allreduce_cost,
-    MeshTopoInfo,
     reduce_scatter_cost,
-    spec_to_bytes,
 )

 from torch.distributed.tensor.placement_types import Partial, Shard
@@ -85,6 +84,9 @@ def redistribute_cost(
         current_spec.shape, current_spec.stride
     )
     for transform_info in transform_infos:
+        assert (
+            current_spec.tensor_meta is not None
+        ), "spec should have tensor meta defined!"
         comm_bytes_gb = (
             current_spec.tensor_meta.dtype.itemsize
             * math.prod(transform_info.logical_shape)
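
On the assert added in PATCH 3/3: tensor_meta on a DTensorSpec is optional, so
reading current_spec.tensor_meta.dtype.itemsize without a guard is rejected by
the type checker; the assert narrows the type and gives a clear failure message.
The import hunk restores the previous ordering and drops the now-unused
spec_to_bytes. A generic sketch of the same narrowing pattern, with illustrative
stand-in classes rather than the real DTensor types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Meta:
        itemsize: int  # bytes per element, e.g. 4 for fp32

    @dataclass
    class Spec:
        tensor_meta: Optional[Meta] = None

    def bytes_per_element(spec: Spec) -> int:
        # without the assert, a checker flags spec.tensor_meta.itemsize because
        # tensor_meta may be None; the assert narrows Optional[Meta] to Meta
        assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
        return spec.tensor_meta.itemsize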