diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 2813412c3..ce431815b 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -34,6 +34,7 @@ jobs:
         run: |
           source .venv/bin/activate
           make test_cov
+          make test_parallel
   cpu-very-fast-no-healpix:
     runs-on: ubuntu-latest
     steps:
@@ -91,3 +92,4 @@ jobs:
           export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
           python3 fme/require_gpu.py
           make test
+          make test_parallel
diff --git a/Makefile b/Makefile
index 4c9874125..4d8d8cdbb 100644
--- a/Makefile
+++ b/Makefile
@@ -5,6 +5,7 @@ ENVIRONMENT_NAME ?= fme
 USERNAME ?= $(shell beaker account whoami --format=json | jq -r '.[0].name')
 DEPLOY_TARGET ?= pypi
 BEAKER_WORKSPACE = ai2/ace
+NPROC ?= 2

 ifeq ($(shell uname), Linux)
 	CONDA_PACKAGES=gxx_linux-64 pip
@@ -53,6 +54,9 @@ create_environment:
 test:
 	pytest --durations 40 .

+test_parallel:
+	torchrun --nproc-per-node $(NPROC) -m pytest ./fme/core/distributed/parallel_tests
+
 # --cov must come after pytest args to use the sources defined by config
 test_cov:
 	pytest --durations 40 --cov --cov-report=term-missing:skip-covered --cov-config=pyproject.toml .
diff --git a/fme/core/distributed/base.py b/fme/core/distributed/base.py
index 254504b0f..95ca09859 100644
--- a/fme/core/distributed/base.py
+++ b/fme/core/distributed/base.py
@@ -16,15 +16,34 @@ def rank(self) -> int:
         """Global rank of this process."""
         ...

+    @property
+    @abstractmethod
+    def data_parallel_rank(self) -> int: ...
+
     @property
     @abstractmethod
     def total_ranks(self) -> int:
         """Total number of processes."""
         ...

+    @property
+    @abstractmethod
+    def total_data_parallel_ranks(self) -> int:
+        """
+        Total number of rank splits along the data parallel dimension.
+
+        For example, 8 ranks using 2 ranks of model parallelism would have
+        only 4 ranks of data parallelism.
+        """
+
     @abstractmethod
     def local_batch_size(self, batch_size: int) -> int: ...

+    @abstractmethod
+    def get_local_slices(
+        self, tensor_shape, rank: int, data_parallel_dim: int | None
+    ): ...
+
     @abstractmethod
     def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor | None: ...

@@ -38,7 +57,9 @@ def reduce_min(self, tensor: torch.Tensor) -> torch.Tensor | None: ...
     def reduce_max(self, tensor: torch.Tensor) -> torch.Tensor | None: ...

     @abstractmethod
-    def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:
+    def gather(
+        self, tensor: torch.Tensor, gather_list: list[torch.Tensor] | None
+    ) -> list[torch.Tensor] | None:
         """
         Gather a tensor from all processes to the root process.

@@ -49,6 +70,8 @@ def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:

         Args:
             tensor: The tensor to gather.
+            gather_list: A list of tensor buffers to gather into,
+                one for each rank.

         Returns:
             A list of tensors, where the i-th element is the tensor
diff --git a/fme/core/distributed/distributed.py b/fme/core/distributed/distributed.py
index 6601cc07d..030503b03 100644
--- a/fme/core/distributed/distributed.py
+++ b/fme/core/distributed/distributed.py
@@ -79,6 +79,25 @@ def rank(self) -> int:
         """
         return self._distributed.rank

+    @property
+    def data_parallel_rank(self) -> int:
+        """
+        Get the data parallel rank of this process.
+
+        In the context of distributed learning, this is the "batch"
+        rank of this process.
+        """
+        return self._distributed.data_parallel_rank
+
+    @property
+    def total_data_parallel_ranks(self) -> int:
+        """
+        Get the total number of data parallel ranks.
+
+        This is the number of parallel splits along the "batch" dimension.
+ """ + return self._distributed.total_data_parallel_ranks + @property def world_size(self) -> int: """ @@ -118,6 +137,36 @@ def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor: """ return self._distributed.reduce_mean(tensor) + def get_local_slices( + self, + tensor_shape, + rank: int | None = None, + data_parallel_dim: int | None = None, + ): + """ + Gets the slice corresponding to the current rank within a global tensor_shape. + + Args: + tensor_shape: the shape of the global tensor, which may or may not contain + a data parallel (batch) dimension. + rank: the rank to retrieve the slice for, defaults to the current rank. + data_parallel_dim: the index of the data parallel dimension, if it exists. + by default, assumes the tensor does not have a data parallel dimension. + """ + if data_parallel_dim is not None and ( + tensor_shape[data_parallel_dim] % self.total_data_parallel_ranks != 0 + ): + raise ValueError( + "expected global data parallel dim to be divisible by data parallel " + f"ranks, got global shape {tensor_shape} with " + f"{self.total_data_parallel_ranks} data parallel ranks" + ) + if rank is None: + rank = self._distributed.rank + return self._distributed.get_local_slices( + tensor_shape, rank=rank, data_parallel_dim=data_parallel_dim + ) + def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: """ Reduce a tensor representing a sum across all processes. @@ -145,7 +194,9 @@ def reduce_max(self, tensor: torch.Tensor) -> torch.Tensor: """ return self._distributed.reduce_max(tensor) - def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None: + def gather( + self, tensor: torch.Tensor, gather_list: list[torch.Tensor] | None = None + ) -> list[torch.Tensor] | None: """ Gather a tensor from all processes to the root process. @@ -156,12 +207,51 @@ def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None: Args: tensor: The tensor to gather. + gather_list: A list of tensor buffers to gather into, + one for each rank. Returns: A list of tensors, where the i-th element is the tensor from the i-th process. """ - return self._distributed.gather(tensor) + return self._distributed.gather(tensor, gather_list=gather_list) + + def gather_global( + self, tensor: torch.Tensor, global_shape, data_parallel_dim: int = 0 + ) -> torch.Tensor | None: + """ + Gathers tensor data into a single tensor with the data from all ranks. 
+
+        Args:
+            tensor: the tensor data to gather
+            global_shape: the shape of the tensor containing data from all ranks
+            data_parallel_dim: the dimension in global_shape corresponding to the
+                data parallel (or "batch") dimension
+        """
+        if global_shape[data_parallel_dim] % self.total_data_parallel_ranks != 0:
+            raise ValueError(
+                "expected global data parallel dim to be divisible by data parallel "
+                f"ranks, got global_shape {global_shape} with "
+                f"{self.total_data_parallel_ranks} data parallel ranks"
+            )
+        if self.is_root():
+            gathered_global = torch.zeros(
+                *global_shape, dtype=tensor.dtype, device=tensor.device
+            )
+            gather_list = []
+            for i in range(self.total_data_parallel_ranks):
+                gather_list.append(
+                    gathered_global[
+                        self.get_local_slices(
+                            global_shape, i, data_parallel_dim=data_parallel_dim
+                        )
+                    ]
+                )
+        else:
+            gather_list = None
+            gathered_global = None
+        self.gather(tensor, gather_list=gather_list)
+        return gathered_global

     def gather_irregular(
         self,
diff --git a/fme/core/distributed/non_distributed.py b/fme/core/distributed/non_distributed.py
index 428eb15de..ca39ff2f5 100644
--- a/fme/core/distributed/non_distributed.py
+++ b/fme/core/distributed/non_distributed.py
@@ -27,11 +27,22 @@ def rank(self) -> int:
         """Global rank of this process."""
         return 0

+    @property
+    def data_parallel_rank(self) -> int:
+        return self.rank  # no model parallelism
+
     @property
     def total_ranks(self) -> int:
         """Total number of processes."""
         return 1

+    @property
+    def total_data_parallel_ranks(self) -> int:
+        return self.total_ranks  # no model parallelism
+
+    def get_local_slices(self, tensor_shape, rank: int, data_parallel_dim: int | None):
+        return tuple(slice(None, None) for _ in tensor_shape)
+
     def local_batch_size(self, batch_size: int) -> int:
         return batch_size

@@ -48,7 +59,16 @@ def reduce_min(self, tensor: torch.Tensor) -> torch.Tensor:
     def reduce_max(self, tensor: torch.Tensor) -> torch.Tensor:
         return tensor

-    def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:
+    def gather(
+        self, tensor: torch.Tensor, gather_list: list[torch.Tensor] | None
+    ) -> list[torch.Tensor] | None:
+        if gather_list is not None:
+            if len(gather_list) != 1:
+                raise ValueError(
+                    f"expected 1 element in gather_list, got {len(gather_list)}"
+                )
+            gather_list[0][:] = tensor
+            return gather_list
         return [tensor]

     def gather_irregular(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:
diff --git a/fme/core/distributed/parallel_tests/README.md b/fme/core/distributed/parallel_tests/README.md
new file mode 100644
index 000000000..a8548762e
--- /dev/null
+++ b/fme/core/distributed/parallel_tests/README.md
@@ -0,0 +1,3 @@
+This directory contains tests that can be run in parallel, for example with torchrun.
+
+They should also pass when run in serial.
diff --git a/fme/core/distributed/parallel_tests/test_local_slices.py b/fme/core/distributed/parallel_tests/test_local_slices.py
new file mode 100644
index 000000000..1a5db5417
--- /dev/null
+++ b/fme/core/distributed/parallel_tests/test_local_slices.py
@@ -0,0 +1,87 @@
+import numpy as np
+import torch
+
+from fme.core import get_device
+from fme.core.distributed import Distributed
+
+
+def test_gather_tensor_from_local_slices():
+    """
+    Only tests get_local_slices and gather_global.
+
+    Because the global data is the same on each rank, there is no coverage of
+    "batch" parallelism in this test.
+ """ + dist = Distributed.get_instance() + global_shape = (2, 4, 4) + x_global = ( + torch.arange(np.prod(global_shape), device=get_device()).reshape(global_shape) + + 1 + ) + x_local = x_global[ + dist.get_local_slices(global_shape, dist.rank, data_parallel_dim=0) + ] + gathered = dist.gather_global( + x_local, global_shape=global_shape, data_parallel_dim=0 + ) + if dist.is_root(): + assert gathered is not None + torch.testing.assert_close(gathered, x_global) + else: + assert gathered is None + + +def test_local_slices_subdivide_domain(): + """ + Only tests get_local_slices and gather. + + Because the global data is the same on each rank, there is no coverage of + "batch" parallelism in this test. + """ + dist = Distributed.get_instance() + global_shape = (2, 4, 4) + x_global = torch.zeros(global_shape, device=get_device()) + total_size = np.prod(global_shape) + assert ( + total_size % dist.world_size == 0 + ), "total_size is not divisible by total ranks" + expected_slice_size = total_size // dist.world_size + for i in range(dist.world_size): + local_slices = dist.get_local_slices(global_shape, i, data_parallel_dim=0) + # the slices should be of the minimum size required + assert x_global[local_slices].nelement() == expected_slice_size + x_global[local_slices] = 1 + torch.testing.assert_close( + x_global, torch.ones_like(x_global) + ) # the entire domain should get selected + + +def test_reduce_mean_from_multiple_ranks(): + """ + dist.reduce_mean should only reduce along the "data parallel" dimension, not + along "model parallel" ranks. + """ + dist = Distributed.get_instance() + global_shape = (4, 4) + x_global_base = torch.arange( + global_shape[0] * global_shape[1], dtype=torch.float32, device=get_device() + ).reshape(global_shape) + # each global/model domain is a reshaped arange, with a different constant offset + # depending on the batch/data parallel index/rank. 
+    x_global_ranked = x_global_base + dist.data_parallel_rank
+    x_local_ranked = x_global_ranked[dist.get_local_slices(global_shape, dist.rank)]
+    x_local_reduced = dist.reduce_mean(x_local_ranked)
+
+    # we expect the offsets to average out, giving the arange map plus an average offset
+    x_global_mean_expected = x_global_base + torch.mean(
+        torch.arange(
+            dist.total_data_parallel_ranks,
+            dtype=x_global_base.dtype,
+            device=x_global_base.device,
+        )
+    )
+    # check the sub-domain we have on the local rank against this expectation
+    x_local_reduced_expected = x_global_mean_expected[
+        dist.get_local_slices(global_shape, dist.rank)
+    ]
+    torch.testing.assert_close(x_local_reduced, x_local_reduced_expected)
diff --git a/fme/core/distributed/torch_distributed.py b/fme/core/distributed/torch_distributed.py
index e040e81e2..256c5b490 100644
--- a/fme/core/distributed/torch_distributed.py
+++ b/fme/core/distributed/torch_distributed.py
@@ -69,11 +69,34 @@ def rank(self) -> int:
         """Global rank of this process."""
         return self._rank

+    @property
+    def data_parallel_rank(self) -> int:
+        return self.rank  # no model parallelism
+
     @property
     def total_ranks(self) -> int:
         """Total number of processes."""
         return self.world_size

+    @property
+    def total_data_parallel_ranks(self) -> int:
+        return self.total_ranks  # no model parallelism
+
+    def get_local_slices(self, tensor_shape, rank: int, data_parallel_dim: int | None):
+        return_list = [slice(None, None) for _ in tensor_shape]
+        if data_parallel_dim is not None:
+            if tensor_shape[data_parallel_dim] % self.total_data_parallel_ranks != 0:
+                raise ValueError(
+                    "expected global data parallel dim to be divisible by data "
+                    f"parallel ranks, got global shape {tensor_shape} with "
+                    f"{self.total_data_parallel_ranks} data parallel ranks"
+                )
+            per_rank = tensor_shape[data_parallel_dim] // self.total_data_parallel_ranks
+            return_list[data_parallel_dim] = slice(
+                rank * per_rank, (rank + 1) * per_rank
+            )
+        return tuple(return_list)
+
     def local_batch_size(self, batch_size: int) -> int:
         return batch_size // self.total_ranks

@@ -93,9 +116,10 @@ def reduce_max(self, tensor: torch.Tensor) -> torch.Tensor | None:
         torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MAX)
         return tensor

-    def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:
-        gather_list: list[torch.Tensor] | None = None
-        if self.rank == 0:
+    def gather(
+        self, tensor: torch.Tensor, gather_list: list[torch.Tensor] | None = None
+    ) -> list[torch.Tensor] | None:
+        if gather_list is None and self.rank == 0:
             gather_list = [tensor] + [
                 torch.empty_like(tensor) for _ in range(self.world_size - 1)
             ]
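
Usage note (not part of the patch): the new tests under `fme/core/distributed/parallel_tests` are still collected by the serial runs (`make test` / `make test_cov`), and the new `test_parallel` target runs them under multiple processes via `torchrun --nproc-per-node $(NPROC) -m pytest ./fme/core/distributed/parallel_tests`, with `NPROC` defaulting to 2. Both CI jobs touched above call `make test_parallel` after their existing test step.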
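
For reference, a minimal standalone sketch of the slicing arithmetic that `TorchDistributed.get_local_slices` implements above: only the data parallel dimension is subdivided evenly by rank, while every other dimension keeps a full `slice(None, None)`. The `local_slices` helper and the example shape below are illustrative only and not part of the patch.

```python
# Illustrative only: a standalone mirror of the slicing logic added to
# TorchDistributed.get_local_slices; the helper name and shapes are hypothetical.
def local_slices(tensor_shape, rank, total_data_parallel_ranks, data_parallel_dim=None):
    slices = [slice(None, None) for _ in tensor_shape]
    if data_parallel_dim is not None:
        if tensor_shape[data_parallel_dim] % total_data_parallel_ranks != 0:
            raise ValueError("data parallel dim must divide evenly across ranks")
        per_rank = tensor_shape[data_parallel_dim] // total_data_parallel_ranks
        # rank i owns the half-open block [i * per_rank, (i + 1) * per_rank)
        slices[data_parallel_dim] = slice(rank * per_rank, (rank + 1) * per_rank)
    return tuple(slices)


# A global batch of shape (8, 4, 4) split across 4 data-parallel ranks gives
# each rank 2 samples along dimension 0:
for r in range(4):
    print(r, local_slices((8, 4, 4), r, 4, data_parallel_dim=0))
# 0 (slice(0, 2, None), slice(None, None, None), slice(None, None, None))
# 1 (slice(2, 4, None), slice(None, None, None), slice(None, None, None))
# 2 (slice(4, 6, None), slice(None, None, None), slice(None, None, None))
# 3 (slice(6, 8, None), slice(None, None, None), slice(None, None, None))
```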