Skip to content

Commit b9bd0de (parent: 77feb44)

Changed file: test/prototype/mx_formats/test_kernels.py (and one other, see below)

File tree: 2 files changed, +6 −3 lines

test/prototype/mx_formats/test_kernels.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -518,7 +518,7 @@ def test_triton_mxfp8_dim0_zeros():
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    torch.cuda.is_available() and not is_sm_at_least_100(),
+    not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
```

torchao/prototype/mx_formats/kernels.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -551,7 +551,8 @@ def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor:
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert x.is_cuda or x.device.type == "xpu"
+    assert output.is_cuda or output.device.type == "xpu"
 
     n_mx_blocks = x.shape[0]
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
```
```diff
@@ -587,7 +588,9 @@ def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor:
     output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)
 
     assert x.is_contiguous()
-    assert x.is_cuda and output.is_cuda
+    assert x.is_cuda or x.device.type == "xpu"
+    assert output.is_cuda or output.device.type == "xpu"
 
     n_mx_blocks = x.shape[0]
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
```

0 commit comments