diff --git a/.github/workflows/4xH100_tests.yml b/.github/workflows/4xH100_tests.yml
index 72faeebebb..4ab2b98744 100644
--- a/.github/workflows/4xH100_tests.yml
+++ b/.github/workflows/4xH100_tests.yml
@@ -47,3 +47,4 @@ jobs:
         pip install . --no-build-isolation
         ./test/float8/test_everything_multi_gpu.sh
         ./test/prototype/mx_formats/test_mx_dtensor.sh
+        ./test/prototype/mx_formats/test_mxfp8_allgather.sh
diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.py b/test/prototype/mx_formats/test_mxfp8_allgather.py
new file mode 100644
index 0000000000..d68d2e7f43
--- /dev/null
+++ b/test/prototype/mx_formats/test_mxfp8_allgather.py
@@ -0,0 +1,103 @@
+import pytest
+import torch
+import torch.distributed as dist
+
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.utils import is_sm_at_least_90, torch_version_at_least
+
+if not torch_version_at_least("2.7.0"):
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+def setup_distributed():
+    dist.init_process_group("nccl")
+    # seed must be the same in all processes
+    torch.manual_seed(42)
+    local_rank = torch.distributed.get_rank()
+    torch.cuda.set_device(local_rank)
+    return local_rank
+
+
+def _test_allgather(local_rank):
+    golden_qdata = (
+        torch.randint(0, 256, (256, 512), dtype=torch.uint8)
+        .to(torch.float8_e5m2)
+        .to(local_rank)
+    )
+
+    # Random e8m0 scale factors, stored as raw uint8 bit patterns
+    golden_scale = (
+        torch.randint(0, 256, (256, 16), dtype=torch.uint8)
+        .view(torch.float8_e8m0fnu)
+        .to(local_rank)
+    )
+
+    # Create golden MXTensor
+    golden_mx = MXTensor(
+        golden_qdata,
+        golden_scale,
+        elem_dtype=torch.float8_e5m2,
+        block_size=32,
+        orig_dtype=torch.float32,
+        kernel_preference=None,
+        act_quant_kwargs=None,
+        is_swizzled_scales=None,
+    )
+
+    world_size = torch.distributed.get_world_size()
+
+    # Each rank gets its shard (split along dim 0; 128 rows per rank with 2 ranks)
+    shard_size = golden_qdata.shape[0] // world_size
+    start_idx = local_rank * shard_size
+    end_idx = (local_rank + 1) * shard_size
+
+    # Create local MXTensor from shard
+    local_mx = MXTensor(
+        golden_qdata[start_idx:end_idx].clone().to(local_rank),
+        golden_scale[start_idx:end_idx].clone().to(local_rank),
+        elem_dtype=torch.float8_e5m2,
+        block_size=32,
+        orig_dtype=torch.float32,
+        kernel_preference=None,
+        act_quant_kwargs=None,
+        is_swizzled_scales=None,
+    )
+
+    # Perform all_gather
+    gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default(
+        local_mx,
+        world_size,
+        "0",
+    )
+    gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx)
+
+    # Verify type
+    assert isinstance(gathered_mx, MXTensor), (
+        f"Expected MXTensor, got {type(gathered_mx)}"
+    )
+
+    # Verify shape
+    assert gathered_mx.shape == golden_mx.shape, (
+        f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}"
+    )
+
+    # Verify qdata matches golden exactly
+    assert torch.equal(gathered_mx.qdata, golden_qdata), "qdata mismatch"
+
+    # Verify scale matches golden exactly
+    assert torch.equal(
+        gathered_mx.scale.view(torch.uint8),
+        golden_scale.view(torch.uint8),
+    ), "scale mismatch"
+
+    assert gathered_mx.block_size == 32
+
+
+if __name__ == "__main__":
+    local_rank = setup_distributed()
+
+    assert is_sm_at_least_90(), "this test requires SM 9.0 or newer"
+
+    _test_allgather(local_rank)
+
+    torch.distributed.destroy_process_group()
diff --git a/test/prototype/mx_formats/test_mxfp8_allgather.sh b/test/prototype/mx_formats/test_mxfp8_allgather.sh
new file mode 100644
index 0000000000..180375af40
--- /dev/null
+++ b/test/prototype/mx_formats/test_mxfp8_allgather.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+# terminate script on first error
+set -e
+
+if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
+    echo "Skipping test_mxfp8_allgather.sh because no CUDA devices are available."
+    exit
+fi
+
+# multi-GPU test for the MXTensor all_gather path
+NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mxfp8_allgather.py
\ No newline at end of file
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index e9f7225647..9bd1897074 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -842,3 +842,82 @@ def mx_select(func, types, args, kwargs):
         old_mx_tensor._is_swizzled_scales,
     )
     return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor)
+
+
+@implements([torch.ops._c10d_functional.all_gather_into_tensor.default])
+def mx_all_gather(func, types, args, kwargs):
+    """
+    All-gather for MXTensor.
+
+    Args:
+        func: the operation (all_gather_into_tensor)
+        types: tensor types involved
+        args: (mx_tensor, group_size, group_name, ...)
+        kwargs: additional arguments
+    """
+    mx_tensor = args[0]
+    group_size = args[1]
+
+    # TODO: gather qdata and scale in a single concatenated collective as a future optimization
+
+    # Gather both the quantized data and the scales
+    gathered_qdata = torch.ops._c10d_functional.all_gather_into_tensor.default(
+        mx_tensor.qdata,  # the quantized data
+        group_size,
+        *args[2:],
+        **kwargs,
+    )
+
+    gathered_scale = torch.ops._c10d_functional.all_gather_into_tensor.default(
+        # Scale factors; view as uint8 because float8_e8m0fnu is not
+        # supported by all_gather.
+        mx_tensor.scale.view(torch.uint8),
+        group_size,
+        *args[2:],
+        **kwargs,
+    )
+
+    gathered_scale = gathered_scale.view(torch.float8_e8m0fnu)
+
+    # Return a new MXTensor wrapping the gathered data
+    return MXTensor(
+        gathered_qdata,
+        gathered_scale,
+        mx_tensor._elem_dtype,
+        mx_tensor.block_size,
+        mx_tensor._orig_dtype,
+        mx_tensor.kernel_preference,
+        mx_tensor.act_quant_kwargs,
+        mx_tensor._is_swizzled_scales,
+    )
+
+
+@implements([torch.ops._c10d_functional.wait_tensor.default])
+def mx_wait_tensor(func, types, args, kwargs):
+    """
+    Wait for an async collective to complete on an MXTensor.
+
+    This is called after collectives like all_gather to ensure
+    the operation has completed before using the tensor.
+    """
+    mx_tensor = args[0]
+
+    # Wait on both components
+    waited_qdata = torch.ops._c10d_functional.wait_tensor.default(
+        mx_tensor.qdata, *args[1:], **kwargs
+    )
+
+    waited_scale = torch.ops._c10d_functional.wait_tensor.default(
+        mx_tensor.scale, *args[1:], **kwargs
+    )
+
+    return MXTensor(
+        waited_qdata,
+        waited_scale,
+        mx_tensor._elem_dtype,
+        mx_tensor.block_size,
+        mx_tensor._orig_dtype,
+        mx_tensor.kernel_preference,
+        mx_tensor.act_quant_kwargs,
+        mx_tensor._is_swizzled_scales,
+    )
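
Note on how the two new handlers compose: all_gather_into_tensor on an MXTensor issues one collective for qdata and one for the uint8-viewed e8m0 scales and returns a new MXTensor; wait_tensor then waits on both components. The following is a minimal sketch of how calling code could drive this path outside the test. all_gather_mx is a hypothetical helper name (not part of this PR), and it assumes an initialized NCCL process group (e.g. launched with torchrun) and the same "0" group name the test uses.

import torch
import torch.distributed as dist

from torchao.prototype.mx_formats.mx_tensor import MXTensor


def all_gather_mx(local_mx: MXTensor, group_name: str = "0") -> MXTensor:
    # Hypothetical helper (not part of this PR): gather dim-0 shards of an
    # MXTensor from every rank via the ops overridden in mx_tensor.py.
    world_size = dist.get_world_size()
    gathered = torch.ops._c10d_functional.all_gather_into_tensor.default(
        local_mx, world_size, group_name
    )
    # The MXTensor wait_tensor override waits on both components (qdata and
    # scale) before handing back a usable MXTensor.
    return torch.ops._c10d_functional.wait_tensor.default(gathered)

As written, each gather costs two collectives (one for qdata, one for the scales); the concatenated-collective TODO in mx_all_gather would fold these into a single call.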