diff --git a/autoparallel/collective_runtime_estimation.py b/autoparallel/collective_runtime_estimation.py
index cbffcd1..ec3f01e 100644
--- a/autoparallel/collective_runtime_estimation.py
+++ b/autoparallel/collective_runtime_estimation.py
@@ -12,7 +12,6 @@
     allgather_cost,
     allreduce_cost,
     reduce_scatter_cost,
-    spec_to_bytes,
 )
 from torch.distributed.tensor.placement_types import Partial, Shard
 
@@ -63,29 +62,50 @@ def redistribute_cost(
     mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
     cost = 0.0
-    comm_bytes_gb = (
-        spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
+    import math
+
+    from torch.distributed._functional_collectives import _are_we_tracing
+    from torch.distributed.tensor._redistribute import (
+        _gen_transform_infos,
+        _gen_transform_infos_non_cached,
     )
-    # Transformation that considered for redistribute cost:
-    # 1. allgather 2. alltoall
-    # 3. allreduce 4. reduce_scatter
-    curr_placements = [current_spec.placements[i] for i in order]
-    tgt_placements = [target_spec.placements[i] for i in order]
+
+    if _are_we_tracing():
+        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
+    else:
+        transform_infos = _gen_transform_infos(current_spec, target_spec)
+
+    # Transformation that considered for redistribute cost:
+    # 1. allgather 2. alltoall
+    # 3. allreduce 4. reduce_scatter
+    # curr_placements = [current_spec.placements[i] for i in order]
+    # tgt_placements = [target_spec.placements[i] for i in order]
     is_contiguous: bool = check_contiguous_sizes_strides(
         current_spec.shape, current_spec.stride
     )
-    for i, current, target in zip(order, curr_placements, tgt_placements):
-        if current == target:
-            continue
-        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
+    for transform_info in transform_infos:
+        assert (
+            current_spec.tensor_meta is not None
+        ), "spec should have tensor meta defined!"
+        comm_bytes_gb = (
+            current_spec.tensor_meta.dtype.itemsize
+            * math.prod(transform_info.logical_shape)
+            / 1024
+            / 1024
+            / 1024
+        )
         if not is_contiguous:
             cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
+        current = transform_info.src_dst_placements[0]
+        target = transform_info.src_dst_placements[1]
+        if current == target:
+            continue
+        mesh_dim = transform_info.mesh_dim
+        num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
         if current.is_shard() and target.is_replicate():
             current = cast(Shard, current)
-            # allgather gives larger comm bytes
-            comm_bytes_gb *= num_devices_on_mesh_dim
             # add up allgather comm cost
-            cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
+            cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim)
             if current.dim != 0:
                 # penalize cases like S(1) -> R as there are additional compute cost
                 # which corresponds to reshuffling the whole output tensor
@@ -98,7 +118,7 @@ def redistribute_cost(
             target = cast(Shard, target)
             # should be alltoall comm, since we haven't implement it yet, add penalty
             # to favor allgather instead
-            cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i)  # us
+            cost += all_to_all_cost(comm_bytes_gb, mesh_topo, mesh_dim)  # us
 
             num_copies = 0
             if current.dim != 0:
@@ -112,11 +132,11 @@ def redistribute_cost(
         elif current.is_partial() and target.is_replicate():
             # add up allreduce comm cost
-            cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
+            cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim)
         elif current.is_partial() and target.is_shard():
             target = cast(Shard, target)
             # add up reduce_scatter comm cost
-            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
+            cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim)
             if target.dim != 0:
                 # penalize cases like P -> S(1) as there are additional compute cost
                 # which corresponds to reshuffling the whole input tensor
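
Note on the sizing change above: the communication volume for each redistribution step is now derived from that step's logical shape and the tensor's dtype itemsize (instead of spec_to_bytes(current_spec) / num_shards for the whole redistribution), and the allgather branch no longer pre-scales the byte count by the number of devices on the mesh dimension. A minimal sketch of the new per-step arithmetic, with a hypothetical dtype and logical shape standing in for current_spec.tensor_meta.dtype and transform_info.logical_shape:

    import math

    import torch

    # Hypothetical stand-ins for values read from the spec / transform step.
    dtype = torch.bfloat16          # current_spec.tensor_meta.dtype
    logical_shape = [4096, 8192]    # transform_info.logical_shape

    # Same arithmetic as the patched loop: bytes moved by this step, in GiB.
    comm_bytes_gb = dtype.itemsize * math.prod(logical_shape) / 1024 / 1024 / 1024
    print(f"{comm_bytes_gb:.4f} GiB")  # 2 * 4096 * 8192 bytes -> 0.0625 GiB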