From cf026903d83b4fc385737b0235bd0e5faf8d7410 Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Thu, 12 Feb 2026 21:33:29 +0000
Subject: [PATCH 1/5] add get_local_slices and test

---
 fme/core/distributed/base.py                  |  3 +++
 fme/core/distributed/distributed.py           |  5 +++++
 fme/core/distributed/non_distributed.py       |  3 +++
 fme/core/distributed/parallel_tests/README.md |  3 +++
 .../parallel_tests/test_local_slices.py       | 18 ++++++++++++++++++
 fme/core/distributed/torch_distributed.py     |  3 +++
 6 files changed, 35 insertions(+)
 create mode 100644 fme/core/distributed/parallel_tests/README.md
 create mode 100644 fme/core/distributed/parallel_tests/test_local_slices.py

diff --git a/fme/core/distributed/base.py b/fme/core/distributed/base.py
index 254504b0f..6ad34a1c4 100644
--- a/fme/core/distributed/base.py
+++ b/fme/core/distributed/base.py
@@ -25,6 +25,9 @@ def total_ranks(self) -> int:
     @abstractmethod
     def local_batch_size(self, batch_size: int) -> int: ...
 
+    @abstractmethod
+    def get_local_slices(self, tensor_shape, rank: int): ...
+
     @abstractmethod
     def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor | None: ...
 
diff --git a/fme/core/distributed/distributed.py b/fme/core/distributed/distributed.py
index 6601cc07d..7cabdf9c0 100644
--- a/fme/core/distributed/distributed.py
+++ b/fme/core/distributed/distributed.py
@@ -118,6 +118,11 @@ def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:
         """
         return self._distributed.reduce_mean(tensor)
 
+    def get_local_slices(self, tensor_shape, rank: int | None = None):
+        if rank is None:
+            rank = self._distributed.rank
+        return self._distributed.get_local_slices(tensor_shape, rank=rank)
+
     def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
         """
         Reduce a tensor representing a sum across all processes.
diff --git a/fme/core/distributed/non_distributed.py b/fme/core/distributed/non_distributed.py
index 428eb15de..b9dbcbc08 100644
--- a/fme/core/distributed/non_distributed.py
+++ b/fme/core/distributed/non_distributed.py
@@ -32,6 +32,9 @@ def total_ranks(self) -> int:
         """Total number of processes."""
         return 1
 
+    def get_local_slices(self, tensor_shape, rank: int):
+        return tuple(slice(None, None) for _ in tensor_shape)
+
     def local_batch_size(self, batch_size: int) -> int:
         return batch_size
 
diff --git a/fme/core/distributed/parallel_tests/README.md b/fme/core/distributed/parallel_tests/README.md
new file mode 100644
index 000000000..a8548762e
--- /dev/null
+++ b/fme/core/distributed/parallel_tests/README.md
@@ -0,0 +1,3 @@
+This directory contains tests which can be run in parallel, for example with torchrun.
+
+They should also run in serial.
diff --git a/fme/core/distributed/parallel_tests/test_local_slices.py b/fme/core/distributed/parallel_tests/test_local_slices.py
new file mode 100644
index 000000000..5dc3131af
--- /dev/null
+++ b/fme/core/distributed/parallel_tests/test_local_slices.py
@@ -0,0 +1,18 @@
+import torch
+
+from fme.core.distributed import Distributed
+
+
+def test_gather_tensor_from_local_slices():
+    dist = Distributed.get_instance()
+    global_shape = (4, 4)
+    x_global = torch.arange(global_shape[0] * global_shape[1]).reshape(global_shape) + 1
+    x_local = x_global[dist.get_local_slices(global_shape, dist.rank)]
+    gathered = dist.gather(x_local)
+    if dist.is_root():
+        if gathered is None:
+            raise RuntimeError("expected non-none gathered on root rank")
+        gathered_global = torch.zeros_like(x_global)
+        for i, local in enumerate(gathered):
+            gathered_global[dist.get_local_slices(global_shape, i)] = local
+        torch.testing.assert_close(gathered_global, x_global)
diff --git a/fme/core/distributed/torch_distributed.py b/fme/core/distributed/torch_distributed.py
index e040e81e2..0b380f998 100644
--- a/fme/core/distributed/torch_distributed.py
+++ b/fme/core/distributed/torch_distributed.py
@@ -74,6 +74,9 @@ def total_ranks(self) -> int:
         """Total number of processes."""
         return self.world_size
 
+    def get_local_slices(self, tensor_shape, rank: int):
+        return tuple(slice(None, None) for _ in tensor_shape)
+
     def local_batch_size(self, batch_size: int) -> int:
         return batch_size // self.total_ranks
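Note on the API introduced in PATCH 1: get_local_slices(tensor_shape, rank) returns one slice per dimension, describing the part of a global tensor that the given rank owns. Both backends currently return full slices, since there is no model parallelism yet, so the sketch below (using a hypothetical (4, 4) shape, under either backend) shows only the degenerate contract; under real model parallelism the slices would differ by rank.

    import torch

    from fme.core.distributed import Distributed

    dist = Distributed.get_instance()
    global_shape = (4, 4)
    x_global = torch.arange(16).reshape(global_shape)
    # With the placeholder implementations above, every rank gets
    # (slice(None, None), slice(None, None)), i.e. the whole tensor.
    x_local = x_global[dist.get_local_slices(global_shape, dist.rank)]
    assert x_local.shape == x_global.shape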
From 36e4ed0328ddd1de020d283113922e458b8d1c65 Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Thu, 12 Feb 2026 21:42:09 +0000
Subject: [PATCH 2/5] move gather onto Distributed

---
 fme/core/distributed/distributed.py             | 14 +++++++++++++
 .../distributed/parallel_tests/__init__.py      |  0
 .../parallel_tests/test_local_slices.py         | 20 +++++++++++--------
 3 files changed, 26 insertions(+), 8 deletions(-)
 create mode 100644 fme/core/distributed/parallel_tests/__init__.py

diff --git a/fme/core/distributed/distributed.py b/fme/core/distributed/distributed.py
index 7cabdf9c0..3b73f2d24 100644
--- a/fme/core/distributed/distributed.py
+++ b/fme/core/distributed/distributed.py
@@ -168,6 +168,20 @@ def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:
         """
         return self._distributed.gather(tensor)
 
+    def gather_global(self, tensor: torch.Tensor, global_shape) -> torch.Tensor | None:
+        gathered = self.gather(tensor)
+        if self.is_root():
+            if gathered is None:
+                raise RuntimeError("expected non-none gathered on root rank")
+            gathered_global = torch.zeros(
+                *global_shape, dtype=tensor.dtype, device=tensor.device
+            )
+            for i, local in enumerate(gathered):
+                gathered_global[self.get_local_slices(global_shape, i)] = local
+            return gathered_global
+        else:
+            return None
+
     def gather_irregular(
         self,
         tensor: torch.Tensor,
diff --git a/fme/core/distributed/parallel_tests/__init__.py b/fme/core/distributed/parallel_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fme/core/distributed/parallel_tests/test_local_slices.py b/fme/core/distributed/parallel_tests/test_local_slices.py
index 5dc3131af..adee7616a 100644
--- a/fme/core/distributed/parallel_tests/test_local_slices.py
+++ b/fme/core/distributed/parallel_tests/test_local_slices.py
@@ -1,18 +1,22 @@
 import torch
 
+from fme.core import get_device
 from fme.core.distributed import Distributed
 
 
 def test_gather_tensor_from_local_slices():
     dist = Distributed.get_instance()
     global_shape = (4, 4)
-    x_global = torch.arange(global_shape[0] * global_shape[1]).reshape(global_shape) + 1
+    x_global = (
+        torch.arange(global_shape[0] * global_shape[1], device=get_device()).reshape(
+            global_shape
+        )
+        + 1
+    )
     x_local = x_global[dist.get_local_slices(global_shape, dist.rank)]
-    gathered = dist.gather(x_local)
+    gathered = dist.gather_global(x_local, global_shape=global_shape)
     if dist.is_root():
-        if gathered is None:
-            raise RuntimeError("expected non-none gathered on root rank")
-        gathered_global = torch.zeros_like(x_global)
-        for i, local in enumerate(gathered):
-            gathered_global[dist.get_local_slices(global_shape, i)] = local
-        torch.testing.assert_close(gathered_global, x_global)
+        assert gathered is not None
+        torch.testing.assert_close(gathered, x_global)
+    else:
+        assert gathered is None
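Note on PATCH 2: gather_global composes gather with get_local_slices, so callers no longer hand-assemble the global tensor. The contract is that the root rank receives the assembled tensor and every other rank receives None. A minimal usage sketch, reusing the (4, 4) shape and the get_device import from the test above:

    import torch

    from fme.core import get_device
    from fme.core.distributed import Distributed

    dist = Distributed.get_instance()
    global_shape = (4, 4)
    x_global = torch.arange(16, device=get_device()).reshape(global_shape)
    x_local = x_global[dist.get_local_slices(global_shape, dist.rank)]
    out = dist.gather_global(x_local, global_shape=global_shape)
    if dist.is_root():
        # The root reassembles the shards into the full global tensor.
        assert out is not None and tuple(out.shape) == global_shape
    else:
        # Non-root ranks get None rather than a partial result.
        assert out is None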
From 36b1835e22403712e61b7e00083d43c9b19c8f63 Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Thu, 12 Feb 2026 22:07:01 +0000
Subject: [PATCH 3/5] add test of reduce_mean

---
 fme/core/distributed/base.py                  | 14 +++++++
 fme/core/distributed/distributed.py           | 19 ++++++++++
 fme/core/distributed/non_distributed.py       |  8 ++++
 .../parallel_tests/test_local_slices.py       | 37 +++++++++++++++++++
 fme/core/distributed/torch_distributed.py     |  8 ++++
 5 files changed, 86 insertions(+)

diff --git a/fme/core/distributed/base.py b/fme/core/distributed/base.py
index 6ad34a1c4..d9656ec41 100644
--- a/fme/core/distributed/base.py
+++ b/fme/core/distributed/base.py
@@ -16,12 +16,26 @@ def rank(self) -> int:
         """Global rank of this process."""
         ...
 
+    @property
+    @abstractmethod
+    def data_parallel_rank(self) -> int: ...
+
     @property
     @abstractmethod
     def total_ranks(self) -> int:
         """Total number of processes."""
         ...
 
+    @property
+    @abstractmethod
+    def total_data_parallel_ranks(self) -> int:
+        """
+        Total number of rank splits along the data parallel dimension.
+
+        For example, 8 ranks using 2 ranks of model parallelism would have
+        only 4 ranks of data parallelism.
+        """
+
     @abstractmethod
     def local_batch_size(self, batch_size: int) -> int: ...
diff --git a/fme/core/distributed/distributed.py b/fme/core/distributed/distributed.py
index 3b73f2d24..f511b70ac 100644
--- a/fme/core/distributed/distributed.py
+++ b/fme/core/distributed/distributed.py
@@ -79,6 +79,25 @@ def rank(self) -> int:
         """
         return self._distributed.rank
 
+    @property
+    def data_parallel_rank(self) -> int:
+        """
+        Get the data parallel rank of this process.
+
+        In the context of distributed learning, this is the "batch"
+        rank of this process.
+        """
+        return self._distributed.data_parallel_rank
+
+    @property
+    def total_data_parallel_ranks(self) -> int:
+        """
+        Get the total number of data parallel ranks.
+
+        This is the number of parallel splits along the "batch" dimension.
+        """
+        return self._distributed.total_data_parallel_ranks
+
     @property
     def world_size(self) -> int:
         """
diff --git a/fme/core/distributed/non_distributed.py b/fme/core/distributed/non_distributed.py
index b9dbcbc08..ec70f211d 100644
--- a/fme/core/distributed/non_distributed.py
+++ b/fme/core/distributed/non_distributed.py
@@ -27,11 +27,19 @@ def rank(self) -> int:
         """Global rank of this process."""
         return 0
 
+    @property
+    def data_parallel_rank(self) -> int:
+        return self.rank  # no model parallelism
+
     @property
     def total_ranks(self) -> int:
         """Total number of processes."""
         return 1
 
+    @property
+    def total_data_parallel_ranks(self) -> int:
+        return self.total_ranks  # no model parallelism
+
     def get_local_slices(self, tensor_shape, rank: int):
         return tuple(slice(None, None) for _ in tensor_shape)
 
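To make the total_data_parallel_ranks docstring concrete: with 8 total ranks and a model-parallel width of 2, there are 4 data-parallel ranks. The sketch below shows one possible mapping; the numbers and the rank // width layout are illustrative only, since the backends in this series implement just the degenerate case where the model-parallel width is 1 and data_parallel_rank == rank.

    total_ranks = 8
    model_parallel_width = 2  # hypothetical; not configurable in this series

    # 8 ranks split into groups of 2 model-parallel workers -> 4 groups.
    total_data_parallel_ranks = total_ranks // model_parallel_width

    def data_parallel_rank(rank: int) -> int:
        # One simple layout: adjacent global ranks share a model-parallel
        # group; strided layouts are equally valid.
        return rank // model_parallel_width

    assert total_data_parallel_ranks == 4
    assert [data_parallel_rank(r) for r in range(total_ranks)] == [0, 0, 1, 1, 2, 2, 3, 3]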
+ """ + return self._distributed.total_data_parallel_ranks + @property def world_size(self) -> int: """ diff --git a/fme/core/distributed/non_distributed.py b/fme/core/distributed/non_distributed.py index b9dbcbc08..ec70f211d 100644 --- a/fme/core/distributed/non_distributed.py +++ b/fme/core/distributed/non_distributed.py @@ -27,11 +27,19 @@ def rank(self) -> int: """Global rank of this process.""" return 0 + @property + def data_parallel_rank(self) -> int: + return self.rank # no model parallelism + @property def total_ranks(self) -> int: """Total number of processes.""" return 1 + @property + def total_data_parallel_ranks(self) -> int: + return self.total_ranks # no model parallelism + def get_local_slices(self, tensor_shape, rank: int): return tuple(slice(None, None) for _ in tensor_shape) diff --git a/fme/core/distributed/parallel_tests/test_local_slices.py b/fme/core/distributed/parallel_tests/test_local_slices.py index adee7616a..b5623b1be 100644 --- a/fme/core/distributed/parallel_tests/test_local_slices.py +++ b/fme/core/distributed/parallel_tests/test_local_slices.py @@ -5,6 +5,12 @@ def test_gather_tensor_from_local_slices(): + """ + Only tests get_local_slices and gather. + + Because the global data is the same on each rank, there is no coverage of + "batch" parallelism in this test. + """ dist = Distributed.get_instance() global_shape = (4, 4) x_global = ( @@ -20,3 +26,34 @@ def test_gather_tensor_from_local_slices(): torch.testing.assert_close(gathered, x_global) else: assert gathered is None + + +def test_reduce_mean_from_multiple_ranks(): + """ + dist.reduce_mean should only reduce along the "data parallel" dimension, not + along "model parallel" ranks. + """ + dist = Distributed.get_instance() + global_shape = (4, 4) + x_global_base = torch.arange( + global_shape[0] * global_shape[1], dtype=torch.float32, device=get_device() + ).reshape(global_shape) + # each global/model domain is a reshaped arange, with a different constant offset + # depending on the batch/data parallel index/rank. 
From 755cb05c3a3ca0aa33bdb90bafdb3356f92c314b Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Thu, 12 Feb 2026 22:09:58 +0000
Subject: [PATCH 4/5] run parallel tests in CI

---
 .github/workflows/tests.yaml | 2 ++
 Makefile                     | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 2813412c3..ce431815b 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -34,6 +34,7 @@ jobs:
         run: |
           source .venv/bin/activate
           make test_cov
+          make test_parallel
   cpu-very-fast-no-healpix:
     runs-on: ubuntu-latest
     steps:
@@ -91,3 +92,4 @@
           export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
           python3 fme/require_gpu.py
           make test
+          make test_parallel
diff --git a/Makefile b/Makefile
index 4c9874125..1646c8e12 100644
--- a/Makefile
+++ b/Makefile
@@ -53,6 +53,9 @@ create_environment:
 test:
 	pytest --durations 40 .
 
+test_parallel:
+	torchrun --nproc-per-node 2 -m pytest ./fme/core/distributed/parallel_tests
+
 # --cov must come after pytest args to use the sources defined by config
 test_cov:
 	pytest --durations 40 --cov --cov-report=term-missing:skip-covered --cov-config=pyproject.toml .

From 1c6366f3c3e887cfad4fbf44180f5348a77e49ea Mon Sep 17 00:00:00 2001
From: Jeremy McGibbon
Date: Thu, 12 Feb 2026 22:29:44 +0000
Subject: [PATCH 5/5] remove unnecessary __init__.py

---
 fme/core/distributed/parallel_tests/__init__.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 fme/core/distributed/parallel_tests/__init__.py

diff --git a/fme/core/distributed/parallel_tests/__init__.py b/fme/core/distributed/parallel_tests/__init__.py
deleted file mode 100644
index e69de29bb..000000000
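As the README from PATCH 1 promises, the suite runs both serially and in parallel: plain pytest exercises the single-process path, while the new make test_parallel target runs the same files under torchrun with two processes. A small sketch of how a process can tell the two modes apart, assuming Distributed.get_instance() falls back to the non-distributed backend when no process group is launched:

    from fme.core.distributed import Distributed

    dist = Distributed.get_instance()
    # Expected to be 1 under plain pytest and 2 under the Makefile's
    # `torchrun --nproc-per-node 2` invocation.
    print(dist.total_ranks)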