Run parallel distributed tests in CI #832
base: main
Changes from all commits: cf02690, 36e4ed0, 36b1835, 755cb05, 682cc79, 1c6366f, 01e71c3
@@ -79,6 +79,25 @@ def rank(self) -> int:

```python
        """
        return self._distributed.rank

    @property
    def data_parallel_rank(self) -> int:
        """
        Get the data parallel rank of this process.

        In the context of distributed learning, this is the "batch"
        rank of this process.
        """
        return self._distributed.data_parallel_rank

    @property
    def total_data_parallel_ranks(self) -> int:
        """
        Get the total number of data parallel ranks.

        This is the number of parallel splits along the "batch" dimension.
        """
        return self._distributed.total_data_parallel_ranks

    @property
    def world_size(self) -> int:
        """
```
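For orientation, a minimal sketch (not part of this diff) of how a data loader might consume the two new properties, assuming the standard torch `DistributedSampler`; the dataset here is a placeholder. Under combined data/model parallelism the batch shard is selected by `data_parallel_rank` rather than the global `rank`:

```python
# Hypothetical consumer of the new properties (not part of this PR): shard a
# toy dataset along the batch dimension using the data-parallel rank, so that
# model-parallel peers see the same batch shard.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

from fme.core.distributed import Distributed

dist = Distributed.get_instance()
dataset = TensorDataset(torch.randn(64, 8))  # placeholder dataset
sampler = DistributedSampler(
    dataset,
    num_replicas=dist.total_data_parallel_ranks,  # splits along the "batch" dimension
    rank=dist.data_parallel_rank,  # not dist.rank: model-parallel peers share a shard
)
```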
@@ -118,6 +137,11 @@ def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor:

```python
        """
        return self._distributed.reduce_mean(tensor)

    def get_local_slices(self, tensor_shape, rank: int | None = None):
        if rank is None:
            rank = self._distributed.rank
        return self._distributed.get_local_slices(tensor_shape, rank=rank)

    def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Reduce a tensor representing a sum across all processes.
```
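The wrapped implementation of `get_local_slices` is not part of this diff. From its use in the tests below, it appears to return a tuple of `slice` objects selecting the sub-domain of a globally shaped tensor owned by a given rank. A simplified, single-axis sketch of that idea, for illustration only (the real partitioning in `self._distributed` may split other dimensions):

```python
# Simplified stand-in for get_local_slices, assuming an even split of the
# first dimension across `total_ranks`; the actual method may partition
# spatial dimensions differently.
def local_slices_sketch(tensor_shape, rank, total_ranks):
    chunk = tensor_shape[0] // total_ranks
    first = slice(rank * chunk, (rank + 1) * chunk)
    # keep all remaining dimensions whole in this sketch
    return (first,) + tuple(slice(None) for _ in tensor_shape[1:])

# With 2 ranks and a (4, 4) global shape:
#   rank 0 -> (slice(0, 2), slice(None, None, None))
#   rank 1 -> (slice(2, 4), slice(None, None, None))
```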
@@ -163,6 +187,20 @@ def gather(self, tensor: torch.Tensor) -> list[torch.Tensor] | None:

```python
        """
        return self._distributed.gather(tensor)

    def gather_global(self, tensor: torch.Tensor, global_shape) -> torch.Tensor | None:
        gathered = self.gather(tensor)
        if self.is_root():
            if gathered is None:
                raise RuntimeError("expected non-none gathered on root rank")
            gathered_global = torch.zeros(
                *global_shape, dtype=tensor.dtype, device=tensor.device
            )
```
Comment on lines +195 to +197 (Contributor): Might be a spatial parallelism PR question, do we know how often we are calling
```python
            for i, local in enumerate(gathered):
                gathered_global[self.get_local_slices(global_shape, i)] = local
            return gathered_global
        else:
            return None

    def gather_irregular(
        self,
        tensor: torch.Tensor,
```
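A hypothetical caller sketch for `gather_global` (not from this PR): reassemble the full field on the root rank and save it there, while non-root ranks receive `None` and skip the write. The field contents and output path below are placeholders:

```python
# Hypothetical use of gather_global: only the root rank gets the reassembled
# tensor; every other rank gets None. Names and paths are illustrative.
import torch

from fme.core import get_device
from fme.core.distributed import Distributed

dist = Distributed.get_instance()
global_shape = (16, 32)
field_global = torch.arange(
    global_shape[0] * global_shape[1], dtype=torch.float32, device=get_device()
).reshape(global_shape)
field_local = field_global[dist.get_local_slices(global_shape, dist.rank)]

field_full = dist.gather_global(field_local, global_shape=global_shape)
if dist.is_root():
    torch.save(field_full.cpu(), "field.pt")  # placeholder output path
```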
@@ -0,0 +1,3 @@

```
This directory contains tests which can be run in parallel, for example with torchrun.

They should also run in serial.
```
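As a rough sketch of the torchrun route mentioned above (the test-directory path, process count, and wrapping in a Python `subprocess` call are all assumptions, not taken from this PR):

```python
# Hedged sketch: launch pytest under torchrun so each test runs on several
# ranks at once. Uses torchrun's documented --nproc_per_node and -m/--module
# options; the path and process count below are placeholders.
import subprocess

subprocess.run(
    [
        "torchrun",
        "--nproc_per_node=2",       # number of parallel processes (assumption)
        "-m",
        "pytest",
        "path/to/parallel/tests",   # placeholder path, not given in this diff
    ],
    check=True,
)
```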
@@ -0,0 +1,59 @@

```python
import torch

from fme.core import get_device
from fme.core.distributed import Distributed


def test_gather_tensor_from_local_slices():
    """
    Only tests get_local_slices and gather.

    Because the global data is the same on each rank, there is no coverage of
    "batch" parallelism in this test.
    """
    dist = Distributed.get_instance()
    global_shape = (4, 4)
    x_global = (
        torch.arange(global_shape[0] * global_shape[1], device=get_device()).reshape(
            global_shape
        )
        + 1
    )
    x_local = x_global[dist.get_local_slices(global_shape, dist.rank)]
    gathered = dist.gather_global(x_local, global_shape=global_shape)
    if dist.is_root():
        assert gathered is not None
        torch.testing.assert_close(gathered, x_global)
    else:
        assert gathered is None


def test_reduce_mean_from_multiple_ranks():
    """
    dist.reduce_mean should only reduce along the "data parallel" dimension, not
    along "model parallel" ranks.
    """
    dist = Distributed.get_instance()
    global_shape = (4, 4)
    x_global_base = torch.arange(
        global_shape[0] * global_shape[1], dtype=torch.float32, device=get_device()
    ).reshape(global_shape)
    # each global/model domain is a reshaped arange, with a different constant offset
    # depending on the batch/data parallel index/rank.
    x_global_ranked = x_global_base + dist.data_parallel_rank
    x_local_ranked = x_global_ranked[dist.get_local_slices(global_shape, dist.rank)]
    x_local_reduced = dist.reduce_mean(x_local_ranked)

    # we expect the offsets to average out, giving the arange map plus an average offset
    x_global_mean_expected = x_global_base + torch.mean(
        torch.arange(
            dist.total_data_parallel_ranks,
            dtype=x_global_base.dtype,
            device=x_global_base.device,
        )
    )
    # check the sub-domain we have on the local rank against this expectation
    x_local_reduced_expected = x_global_mean_expected[
        dist.get_local_slices(global_shape, dist.rank)
    ]
    torch.testing.assert_close(x_local_reduced, x_local_reduced_expected)
```
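To make the expectation in `test_reduce_mean_from_multiple_ranks` concrete, here is a tiny worked check assuming (hypothetically) two data-parallel ranks: the per-rank offsets 0 and 1 average to 0.5, so the reduced field should equal the base arange map plus 0.5 everywhere.

```python
# Worked check of the reduce_mean expectation, assuming 2 data-parallel ranks.
import torch

x_global_base = torch.arange(16, dtype=torch.float32).reshape(4, 4)
offsets = torch.arange(2, dtype=torch.float32)    # per-rank offsets: 0., 1.
expected = x_global_base + offsets.mean()         # arange map + 0.5
assert torch.allclose(expected - x_global_base, torch.full((4, 4), 0.5))
```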
[nit] Make the number of processors a Make variable so you can use a different number when calling `make test_parallel`.