diff --git a/lm_engine/utils/parallel.py b/lm_engine/utils/parallel.py
index 689710b4b..96992b8b2 100644
--- a/lm_engine/utils/parallel.py
+++ b/lm_engine/utils/parallel.py
@@ -15,6 +15,7 @@
 from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
+from .accelerator import Accelerator
 from .miscellaneous import divide_if_divisible
 from .packages import is_torch_xla_available
 
@@ -26,7 +27,7 @@
 
 
 # general
-_MESH: DeviceMesh | None = None
+_DENSE_MESH: DeviceMesh | None = None
 _GLOBAL_RANK: int | None = None
 _LOCAL_RANK: int | None = None
 _WORLD_SIZE: int | None = None
@@ -65,9 +66,7 @@ def __init__(
         timeout_minutes: int | None = None,
         use_async_tensor_parallel: bool = False,
     ) -> ProcessGroupManager:
-        from .accelerator import Accelerator
-
-        global _MESH
+        global _DENSE_MESH
         global _TENSOR_PARALLEL_FIRST_RANK
         global _DATA_PARALLEL_REPLICATION_WORLD_SIZE
         global _DATA_PARALLEL_SHARDING_WORLD_SIZE
@@ -129,7 +128,7 @@ def __init__(
         _DATA_PARALLEL_SHARDING_WORLD_SIZE = data_parallel_sharding_world_size
 
         # FIXME unable to use XLA mesh since XLA mesh doesn't support accessing submesh
-        _MESH = init_device_mesh(
+        _DENSE_MESH = init_device_mesh(
             "cpu" if accelerator == Accelerator.tpu else Accelerator.get_device_type(),
             (
                 pipeline_parallel_world_size,
@@ -158,23 +157,19 @@ def is_initialized() -> bool:
         return torch.distributed.is_initialized()
 
     @staticmethod
-    def get_mesh() -> DeviceMesh:
-        global _MESH
-        return _MESH
+    def get_dense_mesh() -> DeviceMesh:
+        return _DENSE_MESH
 
     @staticmethod
     def get_global_rank() -> int:
-        global _GLOBAL_RANK
         return _GLOBAL_RANK
 
     @staticmethod
     def get_local_rank() -> int:
-        global _LOCAL_RANK
         return _LOCAL_RANK
 
     @staticmethod
     def get_world_size() -> int:
-        global _WORLD_SIZE
         return _WORLD_SIZE
 
     # tensor parallel
@@ -183,7 +178,7 @@ def get_tensor_parallel_mesh() -> DeviceMesh:
         global _TENSOR_PARALLEL_MESH
 
         if _TENSOR_PARALLEL_MESH is None:
-            _TENSOR_PARALLEL_MESH = ProcessGroupManager.get_mesh()["tp"]
+            _TENSOR_PARALLEL_MESH = ProcessGroupManager.get_dense_mesh()["tp"]
         return _TENSOR_PARALLEL_MESH
 
     @staticmethod
@@ -268,7 +263,7 @@ def get_pipeline_parallel_mesh() -> DeviceMesh:
         global _PIPELINE_PARALLEL_MESH
 
         if _PIPELINE_PARALLEL_MESH is None:
-            _PIPELINE_PARALLEL_MESH = ProcessGroupManager.get_mesh()["pp"]
+            _PIPELINE_PARALLEL_MESH = ProcessGroupManager.get_dense_mesh()["pp"]
         return _PIPELINE_PARALLEL_MESH
 
     @staticmethod
@@ -325,7 +320,7 @@ def get_data_parallel_mesh() -> DeviceMesh:
         global _DATA_PARALLEL_MESH
 
         if _DATA_PARALLEL_MESH is None:
-            _DATA_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp"]
+            _DATA_PARALLEL_MESH = ProcessGroupManager.get_dense_mesh()["ddp", "fsdp"]
         return _DATA_PARALLEL_MESH
 
     @staticmethod
@@ -385,7 +380,7 @@ def set_dummy_data_parallel_world_size(world_size: int):
         _DATA_PARALLEL_WORLD_SIZE = original_world_size
 
     def __str__(self) -> str:
-        return str(self.get_mesh())
+        return str(self.get_dense_mesh())
 
     @staticmethod
     def destroy_process_groups() -> None:
@@ -397,7 +392,6 @@ def destroy_process_groups() -> None:
 
     @staticmethod
     def get_cpu_group() -> ProcessGroup | None:
-        global _CPU_GROUP
         return _CPU_GROUP
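
Usage sketch (not part of the diff): call sites now retrieve the full device mesh through the renamed getter. The snippet below assumes the distributed process groups were already initialized via ProcessGroupManager(...) elsewhere in the training entrypoint; the variable names are illustrative.

    from lm_engine.utils.parallel import ProcessGroupManager

    # get_mesh() is renamed to get_dense_mesh(); it returns the DeviceMesh
    # covering the pp/ddp/fsdp/tp dimensions.
    dense_mesh = ProcessGroupManager.get_dense_mesh()

    # The per-dimension accessors are unchanged and slice the dense mesh internally.
    tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()    # dense_mesh["tp"]
    pp_mesh = ProcessGroupManager.get_pipeline_parallel_mesh()  # dense_mesh["pp"]
    dp_mesh = ProcessGroupManager.get_data_parallel_mesh()      # dense_mesh["ddp", "fsdp"]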