Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/4xH100_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ jobs:
pip install . --no-build-isolation
./test/float8/test_everything_multi_gpu.sh
./test/prototype/mx_formats/test_mx_dtensor.sh
./test/prototype/mx_formats/test_mxfp8_allgather.sh
109 changes: 109 additions & 0 deletions test/prototype/mx_formats/test_mxfp8_allgather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import pytest
import torch
import torch.distributed as dist

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import is_sm_at_least_90, torch_version_at_least

if not torch_version_at_least("2.7.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


def setup_distributed():
dist.init_process_group("nccl")
# seed must be the same in all processes
torch.manual_seed(42)
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
return local_rank


def _test_allgather(local_rank):
golden_qdata = (
torch.randint(0, 256, (256, 512), dtype=torch.uint8)
.to(torch.float8_e5m2)
.to(local_rank)
)

# Random scale factors (typically float32 or uint8 for e8m0)
golden_scale = (
torch.randint(0, 256, (256, 16), dtype=torch.uint8)
.view(torch.float8_e8m0fnu)
.to(local_rank)
)

# Create golden MXTensor
golden_mx = MXTensor(
golden_qdata,
golden_scale,
elem_dtype=torch.float8_e5m2,
block_size=32,
orig_dtype=torch.float32,
kernel_preference=None,
act_quant_kwargs=None,
is_swizzled_scales=None,
)

local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

# Each rank gets its shard (split along dim 0)
shard_size = golden_qdata.shape[0] // world_size # 2 rows per rank
start_idx = local_rank * shard_size
end_idx = (local_rank + 1) * shard_size

# Create local MXTensor from shard
local_mx = MXTensor(
golden_qdata[start_idx:end_idx].clone().to(local_rank),
golden_scale[start_idx:end_idx].clone().to(local_rank),
elem_dtype=torch.float8_e5m2,
block_size=32,
orig_dtype=torch.float32,
kernel_preference=None,
act_quant_kwargs=None,
is_swizzled_scales=None,
)

# Perform all_gather
gathered_mx = torch.ops._c10d_functional.all_gather_into_tensor.default(
local_mx,
world_size,
"0",
)
gathered_mx = torch.ops._c10d_functional.wait_tensor.default(gathered_mx)

# Verify type
assert isinstance(gathered_mx, MXTensor), (
f"Expected MXTensor, got {type(gathered_mx)}"
)

# Verify shape
assert gathered_mx.shape == golden_mx.shape, (
f"Shape mismatch: {gathered_mx.shape} vs {golden_mx.shape}"
)

# Verify qdata matches golden exactly
if not torch.equal(gathered_mx.qdata, golden_qdata):
assert False, "qdata mismatch"

# Verify scale matches golden exactly
if not torch.equal(
gathered_mx.scale.view(torch.uint8),
golden_scale.view(torch.uint8),
):
assert False, "scale mismatch"

assert gathered_mx.block_size == 32


if __name__ == "__main__":
local_rank = setup_distributed()

assert is_sm_at_least_90() == True, "SM must be > 9.0"

try:
_test_allgather(local_rank)
except Exception as e:
raise e

torch.distributed.destroy_process_group()
12 changes: 12 additions & 0 deletions test/prototype/mx_formats/test_mxfp8_allgather.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash

# terminate script on first error
set -e

if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
echo "Skipping test_dtensor.sh because no CUDA devices are available."
exit
fi

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/prototype/mx_formats/test_mxfp8_allgather.py
79 changes: 79 additions & 0 deletions torchao/prototype/mx_formats/mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,3 +842,82 @@ def mx_select(func, types, args, kwargs):
old_mx_tensor._is_swizzled_scales,
)
return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor)


@implements([torch.ops._c10d_functional.all_gather_into_tensor.default])
def mx_all_gather(func, types, args, kwargs):
"""
All-gather for MXTensor

Args:
func: The operation (all_gather_into_tensor)
types: Tensor types involved
args: (mx_tensor, group_tag, ...)
kwargs: Additional arguments
"""
mx_tensor = args[0]
group_tag = args[1] if len(args) > 1 else "default"

# TODO: Add support for concat CC as a future optimization

# Gather both data and scale
gathered_qdata = torch.ops._c10d_functional.all_gather_into_tensor.default(
mx_tensor.qdata, # The quantized data
group_tag,
*args[2:],
**kwargs,
)

gathered_scale = torch.ops._c10d_functional.all_gather_into_tensor.default(
mx_tensor.scale.view(
torch.uint8
), # The scale factors, Need to cast to uint8 as float8_e8m0fnu is not support for all gather.
group_tag,
*args[2:],
**kwargs,
)

gathered_scale = gathered_scale.view(torch.float8_e8m0fnu)

# Return new MXTensor with gathered data
return MXTensor(
gathered_qdata,
gathered_scale,
mx_tensor._elem_dtype,
mx_tensor.block_size,
mx_tensor._orig_dtype,
mx_tensor.kernel_preference,
mx_tensor.act_quant_kwargs,
mx_tensor._is_swizzled_scales,
)


@implements([torch.ops._c10d_functional.wait_tensor.default])
def mx_wait_tensor(func, types, args, kwargs):
"""
Wait for async collective to complete on MXTensor

This is called after collectives like all_gather to ensure
the operation has completed before using the tensor.
"""
mx_tensor = args[0]

# Wait on both components
waited_qdata = torch.ops._c10d_functional.wait_tensor.default(
mx_tensor.qdata, *args[1:], **kwargs
)

waited_scale = torch.ops._c10d_functional.wait_tensor.default(
mx_tensor.scale, *args[1:], **kwargs
)

return MXTensor(
waited_qdata,
waited_scale,
mx_tensor._elem_dtype,
mx_tensor.block_size,
mx_tensor._orig_dtype,
mx_tensor.kernel_preference,
mx_tensor.act_quant_kwargs,
mx_tensor._is_swizzled_scales,
)
Loading