-
Notifications
You must be signed in to change notification settings - Fork 386
Add support for MXFP8 All gather #3435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+201
−0
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
b144632
add MXFP8 all gather support
avizon-aws 3c0e6ed
added TODO for future feature
avizon-aws ef6ed8f
remove emoji from comment
avizon-aws f96d168
fixed ruff formating
avizon-aws 094b01c
fixed ruff formatting
avizon-aws 243001b
add mxfp8 and nvfp4 to Llama eval scripts (#3394)
vkuzo 88e2bb9
flip mx inference scaling setting to RCEIL (#3428)
vkuzo ba74266
add CLAUDE.local.md to gitignore (#3437)
vkuzo 11b2401
bump python version in tutorial ci workflow (#3439)
danielvegamyhre 6081c0c
[CPU] Reland qconv fp8 fusion passes (#3433)
Xia-Weiwen 74b84e2
Int8Tensor migration cleanup (#3407)
jcaip 3adc286
[xpu][test] Port 2 test/dtypes_{floatx, bitpacking} UT files to intel…
zxd1997066 24059b0
[xpu][test] Port 2 test/quantization/pt2e/test_{quantize_pt2e, quanti…
zxd1997066 565e813
[Intel GPU] Enable optim SR test (#3055)
arlesniak a99255a
updated test with rebase changes
avizon-aws 49dd2ce
Merge branch 'pytorch:main' into mxfp8_ag_feature
avizon-aws 248a403
added checks to run only on CUDA with compatibility >=9
avizon-aws 08a03ba
updated test for H100
avizon-aws eab336b
added test to workflow
avizon-aws File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| import pytest | ||
| import torch | ||
| import torch.distributed as dist | ||
|
|
||
| from torchao.prototype.mx_formats.mx_tensor import MXTensor | ||
| from torchao.utils import is_sm_at_least_90, torch_version_at_least | ||
|
|
||
| if not torch_version_at_least("2.7.0"): | ||
| pytest.skip("Unsupported PyTorch version", allow_module_level=True) | ||
|
|
||
|
|
||
| def setup_distributed(): | ||
| dist.init_process_group("nccl") | ||
| # seed must be the same in all processes | ||
| torch.manual_seed(42) | ||
| local_rank = torch.distributed.get_rank() | ||
| torch.cuda.set_device(local_rank) | ||
| return local_rank | ||
|
|
||
|
|
||
| def _test_allgather(local_rank): | ||
| golden_qdata = ( | ||
| torch.randint(0, 256, (256, 512), dtype=torch.uint8) | ||
| .to(torch.float8_e5m2) | ||
| .to(local_rank) | ||
| ) | ||
|
|
||
| # Random scale factors (typically float32 or uint8 for e8m0) | ||
| golden_scale = ( | ||
| torch.randint(0, 256, (256, 16), dtype=torch.uint8) | ||
| .view(torch.float8_e8m0fnu) | ||
| .to(local_rank) | ||
| ) | ||
|
|
||
| # Create golden MXTensor | ||
| golden_mx = MXTensor( | ||
| golden_qdata, | ||
| golden_scale, | ||
| elem_dtype=torch.float8_e5m2, | ||
| block_size=32, | ||
| orig_dtype=torch.float32, | ||
| kernel_preference=None, | ||
| act_quant_kwargs=None, | ||
| is_swizzled_scales=None, | ||
| ) | ||
|
|
||
| local_rank = torch.distributed.get_rank() | ||
| world_size = torch.distributed.get_world_size() | ||
|
|
||
| # Each rank gets its shard (split along dim 0) | ||
| shard_size = golden_qdata.shape[0] // world_size # 2 rows per rank | ||
| start_idx = local_rank * shard_size | ||
| end_idx = (local_rank + 1) * shard_size | ||
|
|
||
| # Create local MXTensor from shard | ||
| local_mx = MXTensor( | ||
| golden_qdata[start_idx:end_idx].clone().to(local_rank), | ||
| golden_scale[start_idx:end_idx].clone().to(local_rank), | ||
| elem_dtype=torch.float8_e5m2, | ||
| block_size=32, | ||
| orig_dtype=torch.float32, | ||
| kernel_preference=None, | ||
| act_quant_kwargs=None, | ||
| is_swizzled_scales=None, | ||
| ) | ||
|
|
||
| # Perform all_gather | ||
| gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default( | ||
| local_mx, | ||
| world_size, | ||
| "0", | ||
| ) | ||
| gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx) | ||
|
|
||
| # Verify type | ||
| assert isinstance(gathered_mx, MXTensor), ( | ||
| f"Expected MXTensor, got {type(gathered_mx)}" | ||
| ) | ||
|
|
||
| # Verify shape | ||
| assert gathered_mx.shape == golden_mx.shape, ( | ||
| f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}" | ||
| ) | ||
|
|
||
| # Verify qdata matches golden exactly | ||
| if not torch.equal(gathered_mx.qdata, golden_qdata): | ||
| assert False, "qdata mismatch" | ||
|
|
||
| # Verify scale matches golden exactly | ||
| if not torch.equal( | ||
| gathered_mx.scale.view(torch.uint8), | ||
| golden_scale.view(torch.uint8), | ||
| ): | ||
| assert False, "scale mismatch" | ||
|
|
||
| assert gathered_mx.block_size == 32 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| local_rank = setup_distributed() | ||
|
|
||
| assert is_sm_at_least_90() == True, "SM must be > 9.0" | ||
|
|
||
| try: | ||
| _test_allgather(local_rank) | ||
| except Exception as e: | ||
| raise e | ||
|
|
||
| torch.distributed.destroy_process_group() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| #!/bin/bash | ||
|
|
||
| # terminate script on first error | ||
| set -e | ||
|
|
||
| if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then | ||
| echo "Skipping test_dtensor.sh because no CUDA devices are available." | ||
| exit | ||
| fi | ||
|
|
||
| # integration tests for TP/SP | ||
| NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mxfp8_allgather.py |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.