Conversation

@mcgibbon (Contributor)

This PR adds some non-trivial tests of data and model parallelism.

The implementations are trivial for the non-model-parallel case, but these tests will provide good coverage of the upcoming model-parallel code. This PR puts the infrastructure in place to run these tests and extend them.

Changes:

  • Added get_local_slices, data_parallel_rank, and total_data_parallel_ranks attributes to the Distributed and DistributedBackend classes. These are currently used only in unit tests, but will be needed for spatial parallelism; a sketch of the trivial non-model-parallel case follows this list.

  • Added tests of data and model parallelism under ./fme/core/distributed/parallel_tests.
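For reference, this is a minimal sketch of what the trivial non-model-parallel implementations might look like; the constructor and the exact get_local_slices signature are assumptions for illustration, not taken from this PR.

import torch.distributed as dist


class Distributed:
    """Hypothetical sketch: trivial behavior when there is no model parallelism."""

    def __init__(self):
        # Assumption: rank and world size come straight from torch.distributed,
        # falling back to single-process defaults when it is not initialized.
        self._rank = dist.get_rank() if dist.is_initialized() else 0
        self._world_size = dist.get_world_size() if dist.is_initialized() else 1

    @property
    def data_parallel_rank(self) -> int:
        # Without model parallelism, every rank is a data-parallel rank.
        return self._rank

    @property
    def total_data_parallel_ranks(self) -> int:
        # Likewise, all ranks participate in data parallelism.
        return self._world_size

    def get_local_slices(self, global_shape):
        # Without spatial parallelism each rank holds the full domain, so the
        # local slice of every dimension spans the whole dimension.
        return tuple(slice(0, size) for size in global_shape)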

Comment on lines +195 to +197
gathered_global = torch.zeros(
    *global_shape, dtype=tensor.dtype, device=tensor.device
)

This might be a question for the spatial parallelism PR: do we know how often we call gather_global? If allocating 2x the global tensor becomes a memory issue, we may want to consider passing a pre-allocated buffer to the gather call and using it in place, so that we don't hold both the temporary gathered and gathered_global at the same time. I don't know whether this is an issue, but it is something we might want to consider in the future; a sketch of the buffer-passing idea follows.
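A minimal sketch of that idea, assuming contributions are gathered along the first dimension and that torch.distributed is initialized; gather_global_into and the buffer layout are illustrative assumptions, not this codebase's API.

import torch
import torch.distributed as dist


def gather_global_into(buffer: torch.Tensor, local: torch.Tensor) -> torch.Tensor:
    # all_gather_into_tensor writes every rank's contribution directly into the
    # caller-provided buffer, so no second global-sized temporary is allocated.
    # Assumes buffer.shape[0] == world_size * local.shape[0].
    dist.all_gather_into_tensor(buffer, local.contiguous())
    return buffer


# Usage sketch: allocate the buffer once and reuse it across gather calls.
# buffer = torch.empty(global_shape, dtype=local.dtype, device=local.device)
# gathered_global = gather_global_into(buffer, local)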

pytest --durations 40 .

test_parallel:
	torchrun --nproc-per-node 2 -m pytest ./fme/core/distributed/parallel_tests

[nit] make the number of processes a Make variable so you can use a different number when calling make test_parallel; one possible shape is sketched below.
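A possible sketch, where NPROC is a hypothetical variable name (override with make test_parallel NPROC=4):

NPROC ?= 2

test_parallel:
	torchrun --nproc-per-node $(NPROC) -m pytest ./fme/core/distributed/parallel_tests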
