2 changes: 2 additions & 0 deletions .github/workflows/tests.yaml
@@ -34,6 +34,7 @@ jobs:
run: |
source .venv/bin/activate
make test_cov
make test_parallel
cpu-very-fast-no-healpix:
runs-on: ubuntu-latest
steps:
@@ -91,3 +92,4 @@ jobs:
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
python3 fme/require_gpu.py
make test
make test_parallel
3 changes: 3 additions & 0 deletions Makefile
@@ -53,6 +53,9 @@ create_environment:
test:
pytest --durations 40 .

test_parallel:
torchrun --nproc-per-node 2 -m pytest ./fme/core/distributed/parallel_tests
Review comment (Contributor): [nit] Make the number of processes a Make variable so you can use a different number when calling make test_parallel.


# --cov must come after pytest args to use the sources defined by config
test_cov:
pytest --durations 40 --cov --cov-report=term-missing:skip-covered --cov-config=pyproject.toml .
17 changes: 17 additions & 0 deletions fme/core/distributed/base.py
@@ -16,15 +16,32 @@ def rank(self) -> int:
"""Global rank of this process."""
...

@property
@abstractmethod
def data_parallel_rank(self) -> int: ...

@property
@abstractmethod
def total_ranks(self) -> int:
"""Total number of processes."""
...

@property
@abstractmethod
def total_data_parallel_ranks(self) -> int:
"""
Total number of rank splits along the data parallel dimension.

For example, 8 ranks using 2 ranks of model parallelism would have
only 4 ranks of data parallelism.
"""

@abstractmethod
def local_batch_size(self, batch_size: int) -> int: ...

@abstractmethod
def get_local_slices(self, tensor_shape, rank: int): ...

@abstractmethod
def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor | None: ...

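To make the docstring's example concrete, here is a minimal sketch (not part of this diff) of how a global rank could map onto data-parallel and model-parallel ranks. The rank ordering (model-parallel rank varying fastest) and the helper name split_rank are assumptions for illustration only; this PR stubs both properties to the global rank.

def split_rank(rank: int, total_ranks: int, model_parallel_size: int) -> tuple[int, int]:
    """Hypothetical mapping: the model-parallel rank varies fastest."""
    data_parallel_rank = rank // model_parallel_size
    model_parallel_rank = rank % model_parallel_size
    return data_parallel_rank, model_parallel_rank


# 8 total ranks with 2-way model parallelism -> 4 data-parallel ranks,
# matching the docstring's example above.
total_ranks, model_parallel_size = 8, 2
total_data_parallel_ranks = total_ranks // model_parallel_size
assert total_data_parallel_ranks == 4
assert split_rank(5, total_ranks, model_parallel_size) == (2, 1)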
38 changes: 38 additions & 0 deletions fme/core/distributed/distributed.py
@@ -79,6 +79,25 @@ def rank(self) -> int:
"""
return self._distributed.rank

@property
def data_parallel_rank(self) -> int:
"""
Get the data parallel rank of this process.

In the context of distributed learning, this is the "batch"
rank of this process.
"""
return self._distributed.data_parallel_rank

@property
def total_data_parallel_ranks(self) -> int:
"""
Get the total number of data parallel ranks.

This is the number of parallel splits along the "batch" dimension.
"""
return self._distributed.total_data_parallel_ranks

@property
def world_size(self) -> int:
"""
@@ -118,6 +137,11 @@ def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
"""
return self._distributed.reduce_mean(tensor)

def get_local_slices(self, tensor_shape, rank: int | None = None):
if rank is None:
rank = self._distributed.rank
return self._distributed.get_local_slices(tensor_shape, rank=rank)

def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Reduce a tensor representing a sum across all processes.
@@ -163,6 +187,20 @@ def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:
"""
return self._distributed.gather(tensor)

def gather_global(self, tensor: torch.Tensor, global_shape) -> torch.Tensor | None:
gathered = self.gather(tensor)
if self.is_root():
if gathered is None:
raise RuntimeError("expected non-none gathered on root rank")
gathered_global = torch.zeros(
*global_shape, dtype=tensor.dtype, device=tensor.device
)
Review comment (Contributor) on lines +195 to +197: This might be a question for the spatial parallelism PR, but do we know how often gather_global is called? If allocating 2x the global tensor becomes a memory issue, we may want to consider passing a pre-allocated buffer to the gather call and using it in place, so that we don't hold both the temporary gathered and gathered_global at the same time. I don't know if this is an issue yet, but it is something we might want to consider in the future.

for i, local in enumerate(gathered):
gathered_global[self.get_local_slices(global_shape, i)] = local
return gathered_global
else:
return None

def gather_irregular(
self,
tensor: torch.Tensor,
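To illustrate the review comment above, here is a minimal sketch of one way to bound the extra memory on the root rank: receive each rank's sub-domain one at a time and copy it into a caller-provided buffer instead of materializing the full gathered list. The helper name gather_global_into is hypothetical, the root is assumed to be global rank 0, and raw torch.distributed point-to-point ops are used alongside the Distributed wrapper; this is not necessarily what the reviewer had in mind.

import torch
import torch.distributed as c10d

from fme.core.distributed import Distributed


def gather_global_into(
    tensor: torch.Tensor, out: torch.Tensor | None
) -> torch.Tensor | None:
    """Hypothetical in-place variant of gather_global.

    Peak extra memory on root is a single local sub-domain rather than the
    full gathered list plus the global tensor. Assumes an initialized
    process group with more than one rank and that root is global rank 0.
    """
    dist = Distributed.get_instance()
    if dist.is_root():
        assert out is not None
        # the root's own contribution needs no communication
        out[dist.get_local_slices(out.shape, dist.rank)] = tensor
        for src in range(dist.world_size):
            if src == dist.rank:
                continue
            slices = dist.get_local_slices(out.shape, src)
            buffer = torch.empty(
                out[slices].shape, dtype=out.dtype, device=out.device
            )
            c10d.recv(buffer, src=src)
            out[slices] = buffer
        return out
    else:
        c10d.send(tensor.contiguous(), dst=0)
        return None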
11 changes: 11 additions & 0 deletions fme/core/distributed/non_distributed.py
@@ -27,11 +27,22 @@ def rank(self) -> int:
"""Global rank of this process."""
return 0

@property
def data_parallel_rank(self) -> int:
return self.rank # no model parallelism

@property
def total_ranks(self) -> int:
"""Total number of processes."""
return 1

@property
def total_data_parallel_ranks(self) -> int:
return self.total_ranks # no model parallelism

def get_local_slices(self, tensor_shape, rank: int):
return tuple(slice(None, None) for _ in tensor_shape)

def local_batch_size(self, batch_size: int) -> int:
return batch_size

3 changes: 3 additions & 0 deletions fme/core/distributed/parallel_tests/README.md
@@ -0,0 +1,3 @@
This directory contains tests which can be run in parallel, for example with torchrun.

They should also pass when run in serial (a single process, without torchrun).
59 changes: 59 additions & 0 deletions fme/core/distributed/parallel_tests/test_local_slices.py
@@ -0,0 +1,59 @@
import torch

from fme.core import get_device
from fme.core.distributed import Distributed


def test_gather_tensor_from_local_slices():
"""
Only tests get_local_slices and gather.

Because the global data is the same on each rank, there is no coverage of
"batch" parallelism in this test.
"""
dist = Distributed.get_instance()
global_shape = (4, 4)
x_global = (
torch.arange(global_shape[0] * global_shape[1], device=get_device()).reshape(
global_shape
)
+ 1
)
x_local = x_global[dist.get_local_slices(global_shape, dist.rank)]
gathered = dist.gather_global(x_local, global_shape=global_shape)
if dist.is_root():
assert gathered is not None
torch.testing.assert_close(gathered, x_global)
else:
assert gathered is None


def test_reduce_mean_from_multiple_ranks():
"""
dist.reduce_mean should only reduce along the "data parallel" dimension, not
along "model parallel" ranks.
"""
dist = Distributed.get_instance()
global_shape = (4, 4)
x_global_base = torch.arange(
global_shape[0] * global_shape[1], dtype=torch.float32, device=get_device()
).reshape(global_shape)
# each global/model domain is a reshaped arange, with a different constant offset
# depending on the batch/data parallel index/rank.
x_global_ranked = x_global_base + dist.data_parallel_rank
x_local_ranked = x_global_ranked[dist.get_local_slices(global_shape, dist.rank)]
x_local_reduced = dist.reduce_mean(x_local_ranked)

# we expect the offsets to average out, giving the arange map plus an average offset
x_global_mean_expected = x_global_base + torch.mean(
torch.arange(
dist.total_data_parallel_ranks,
dtype=x_global_base.dtype,
device=x_global_base.device,
)
)
# check the sub-domain we have on the local rank against this expectation
x_local_reduced_expected = x_global_mean_expected[
dist.get_local_slices(global_shape, dist.rank)
]
torch.testing.assert_close(x_local_reduced, x_local_reduced_expected)
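As a quick sanity check on the expectation used in test_reduce_mean_from_multiple_ranks, here is a standalone sketch of the arithmetic for two data-parallel ranks (the rank count and offsets are assumed for illustration and are not part of the test file).

import torch

# With 2 data-parallel ranks the per-rank offsets are 0 and 1, so averaging
# across the data-parallel dimension adds a constant 0.5 to the base arange map.
base = torch.arange(16, dtype=torch.float32).reshape(4, 4)
ranked = [base + offset for offset in (0.0, 1.0)]  # one copy per data-parallel rank
mean = torch.stack(ranked).mean(dim=0)
expected = base + torch.arange(2, dtype=torch.float32).mean()  # base + 0.5
torch.testing.assert_close(mean, expected)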
11 changes: 11 additions & 0 deletions fme/core/distributed/torch_distributed.py
@@ -69,11 +69,22 @@ def rank(self) -> int:
"""Global rank of this process."""
return self._rank

@property
def data_parallel_rank(self) -> int:
return self.rank # no model parallelism

@property
def total_ranks(self) -> int:
"""Total number of processes."""
return self.world_size

@property
def total_data_parallel_ranks(self) -> int:
return self.total_ranks # no model parallelism

def get_local_slices(self, tensor_shape, rank: int):
return tuple(slice(None, None) for _ in tensor_shape)

def local_batch_size(self, batch_size: int) -> int:
return batch_size // self.total_ranks

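Both torch_distributed.py and non_distributed.py currently return full slices from get_local_slices, i.e. there is no spatial decomposition yet. As a purely hypothetical sketch of what a future implementation might do (the 1-D decomposition scheme and the helper name get_local_slices_1d below are assumptions for illustration, not part of this PR), the last dimension could be split evenly across ranks:

import torch


def get_local_slices_1d(tensor_shape, rank: int, total_ranks: int):
    """Hypothetical 1-D decomposition: split the last dimension evenly.

    Assumes the last dimension is divisible by total_ranks.
    """
    *leading, last = tensor_shape
    chunk = last // total_ranks
    return tuple(slice(None) for _ in leading) + (
        slice(rank * chunk, (rank + 1) * chunk),
    )


# a (4, 4) global tensor split across 2 ranks -> each rank owns a (4, 2) sub-domain
x_global = torch.arange(16).reshape(4, 4)
assert x_global[get_local_slices_1d((4, 4), rank=0, total_ranks=2)].shape == (4, 2)
assert x_global[get_local_slices_1d((4, 4), rank=1, total_ranks=2)].shape == (4, 2)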