30 changes: 16 additions & 14 deletions test/prototype/mx_formats/test_kernels.py
@@ -43,12 +43,14 @@
 from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
 )

 torch.manual_seed(0)
+_DEVICE = get_current_accelerator_device()

 if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -395,9 +397,9 @@ def test_fp6_values(dtype_name):
     [
         "cpu",
         pytest.param(
-            "cuda",
+            _DEVICE,
             marks=pytest.mark.skipif(
-                not torch.cuda.is_available(), reason="CUDA not available"
+                not torch.accelerator.is_available(), reason="GPU not available"
             ),
         ),
     ],
@@ -439,13 +441,13 @@ def triton_to_mxfp8_dim0_reference(

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
 def test_triton_mxfp8_dim1_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -454,13 +456,13 @@ def test_triton_mxfp8_dim1_randn(M, K):

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
 def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -469,11 +471,11 @@ def test_triton_mxfp8_dim0_randn(M, K):

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 def test_triton_mxfp8_dim0_zeros():
-    x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
+    x = torch.zeros(128, 256, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
@@ -490,7 +492,7 @@ def test_triton_mxfp8_dim0_zeros():
 @pytest.mark.parametrize("K", (128, 256))
 @pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
 def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
-    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
+    x = torch.zeros(M, K, dtype=orig_dtype, device=_DEVICE)
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
@@ -504,7 +506,7 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize(
     "shape",
     [
@@ -519,7 +521,7 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     ],
 )
 def test_rearrange(shape):
-    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    scales = torch.randint(256, size=shape, device=_DEVICE, dtype=torch.uint8)
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
@@ -545,7 +547,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):

     # Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
     x = (
-        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        torch.arange(0, M * K, dtype=input_dtype, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -583,7 +585,7 @@ def test_cuda_mx_dim0_not_supported():
     M, K = 64, 64
     block_size = 32
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -606,7 +608,7 @@ def test_cuda_mx_dim1_invalid_block_size():

     M, K = 64, 64
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
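
Note (not part of the diff): the pattern these hunks converge on is to resolve the accelerator once at module import and to gate accelerator-only tests on torch.accelerator.is_available() instead of torch.cuda.is_available(). A minimal standalone sketch of that pattern follows; it assumes torchao.utils.get_current_accelerator_device() returns a device usable as the device= argument and falls back to CPU when no accelerator is present.

import pytest
import torch

from torchao.utils import get_current_accelerator_device

# Resolve the active accelerator (cuda, xpu, ...) once; CPU is the assumed fallback.
_DEVICE = get_current_accelerator_device()


@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
def test_device_agnostic_alloc():
    # Allocate on whichever accelerator is active instead of hard-coding "cuda".
    x = torch.randn(128, 256, dtype=torch.bfloat16, device=_DEVICE)
    assert x.shape == (128, 256)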