lm_engine/utils/parallel.py (26 changes: 10 additions & 16 deletions)
@@ -15,6 +15,7 @@
 from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
+from .accelerator import Accelerator
 from .miscellaneous import divide_if_divisible
 from .packages import is_torch_xla_available
 
@@ -26,7 +27,7 @@
 
 
 # general
-_MESH: DeviceMesh | None = None
+_DENSE_MESH: DeviceMesh | None = None
 _GLOBAL_RANK: int | None = None
 _LOCAL_RANK: int | None = None
 _WORLD_SIZE: int | None = None
@@ -65,9 +66,7 @@ def __init__(
         timeout_minutes: int | None = None,
         use_async_tensor_parallel: bool = False,
     ) -> ProcessGroupManager:
-        from .accelerator import Accelerator
-
-        global _MESH
+        global _DENSE_MESH
         global _TENSOR_PARALLEL_FIRST_RANK
         global _DATA_PARALLEL_REPLICATION_WORLD_SIZE
         global _DATA_PARALLEL_SHARDING_WORLD_SIZE
@@ -129,7 +128,7 @@ def __init__(
         _DATA_PARALLEL_SHARDING_WORLD_SIZE = data_parallel_sharding_world_size
 
         # FIXME unable to use XLA mesh since XLA mesh doesn't support accessing submesh
-        _MESH = init_device_mesh(
+        _DENSE_MESH = init_device_mesh(
             "cpu" if accelerator == Accelerator.tpu else Accelerator.get_device_type(),
             (
                 pipeline_parallel_world_size,
@@ -158,23 +157,19 @@ def is_initialized() -> bool:
         return torch.distributed.is_initialized()
 
     @staticmethod
-    def get_mesh() -> DeviceMesh:
-        global _MESH
-        return _MESH
+    def get_dense_mesh() -> DeviceMesh:
+        return _DENSE_MESH
 
     @staticmethod
     def get_global_rank() -> int:
-        global _GLOBAL_RANK
         return _GLOBAL_RANK
 
     @staticmethod
     def get_local_rank() -> int:
-        global _LOCAL_RANK
         return _LOCAL_RANK
 
     @staticmethod
     def get_world_size() -> int:
-        global _WORLD_SIZE
         return _WORLD_SIZE
 
     # tensor parallel
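
Reviewer note on the dropped `global` statements: Python only requires `global` when a function rebinds a module-level name; a bare read already resolves through module scope, so the read-only getters above never needed it. A minimal standalone sketch (not part of this PR) of the distinction:

_WORLD_SIZE: int | None = None  # module-level state, as in parallel.py

def set_world_size(n: int) -> None:
    global _WORLD_SIZE  # rebinding the module-level name: `global` is required
    _WORLD_SIZE = n

def get_world_size() -> int:
    return _WORLD_SIZE  # read-only lookup: no `global` needed

set_world_size(8)
assert get_world_size() == 8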
@@ -183,7 +178,7 @@ def get_tensor_parallel_mesh() -> DeviceMesh:
         global _TENSOR_PARALLEL_MESH
 
         if _TENSOR_PARALLEL_MESH is None:
-            _TENSOR_PARALLEL_MESH = ProcessGroupManager.get_mesh()["tp"]
+            _TENSOR_PARALLEL_MESH = ProcessGroupManager.get_dense_mesh()["tp"]
         return _TENSOR_PARALLEL_MESH
 
     @staticmethod
@@ -268,7 +263,7 @@ def get_pipeline_parallel_mesh() -> DeviceMesh:
         global _PIPELINE_PARALLEL_MESH
 
         if _PIPELINE_PARALLEL_MESH is None:
-            _PIPELINE_PARALLEL_MESH = ProcessGroupManager.get_mesh()["pp"]
+            _PIPELINE_PARALLEL_MESH = ProcessGroupManager.get_dense_mesh()["pp"]
         return _PIPELINE_PARALLEL_MESH
 
     @staticmethod
@@ -325,7 +320,7 @@ def get_data_parallel_mesh() -> DeviceMesh:
         global _DATA_PARALLEL_MESH
 
         if _DATA_PARALLEL_MESH is None:
-            _DATA_PARALLEL_MESH = ProcessGroupManager.get_mesh()["ddp", "fsdp"]
+            _DATA_PARALLEL_MESH = ProcessGroupManager.get_dense_mesh()["ddp", "fsdp"]
         return _DATA_PARALLEL_MESH
 
     @staticmethod
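
Reviewer note on the rename: `_DENSE_MESH` is the full 4-D device mesh, and the tp/pp/ddp-fsdp getters above slice named submeshes out of it. A hedged sketch of the pattern (hypothetical world sizes; dim names taken from this diff), runnable with e.g. `torchrun --nproc-per-node 8`:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")  # CPU backend, for a hardware-agnostic sketch

# hypothetical sizes: pp x ddp x fsdp x tp = 1 x 2 x 2 x 2 = 8 ranks
dense_mesh = init_device_mesh(
    "cpu",
    (1, 2, 2, 2),
    mesh_dim_names=("pp", "ddp", "fsdp", "tp"),
)

tp_mesh = dense_mesh["tp"]           # 1-D submesh, cf. get_tensor_parallel_mesh()
dp_mesh = dense_mesh["ddp", "fsdp"]  # 2-D submesh, cf. get_data_parallel_mesh()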
@@ -385,7 +380,7 @@ def set_dummy_data_parallel_world_size(world_size: int):
         _DATA_PARALLEL_WORLD_SIZE = original_world_size
 
     def __str__(self) -> str:
-        return str(self.get_mesh())
+        return str(self.get_dense_mesh())
 
     @staticmethod
     def destroy_process_groups() -> None:
@@ -397,7 +392,6 @@ def destroy_process_groups() -> None:
 
     @staticmethod
     def get_cpu_group() -> ProcessGroup | None:
-        global _CPU_GROUP
         return _CPU_GROUP
 
 