File tree Expand file tree Collapse file tree 2 files changed +6
-3
lines changed
test/prototype/mx_formats
torchao/prototype/mx_formats Expand file tree Collapse file tree 2 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -518,7 +518,7 @@ def test_triton_mxfp8_dim0_zeros():
518518
519519@pytest .mark .skipif (not has_triton (), reason = "unsupported without triton" )
520520@pytest .mark .skipif (
521- torch . cuda . is_available () and not is_sm_at_least_100 (),
521+ not is_sm_at_least_100 (),
522522 reason = "mxfp8 requires CUDA capability 10.0 or greater" ,
523523)
524524@pytest .mark .parametrize ("M" , (256 , 2048 , 131072 ))
Original file line number Diff line number Diff line change @@ -551,7 +551,8 @@ def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor:
551551 output = torch .empty (* new_shape , device = x .device , dtype = torch .bfloat16 )
552552
553553 assert x .is_contiguous ()
554- assert x .is_cuda and output .is_cuda
554+ assert x .is_cuda or x .device .type == "xpu"
555+ assert output .is_cuda or output .device .type == "xpu"
555556
556557 n_mx_blocks = x .shape [0 ]
557558 grid = lambda meta : (triton .cdiv (n_mx_blocks , meta ["BLOCK_SIZE_IN" ]),)
@@ -587,7 +588,9 @@ def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor:
587588 output = torch .empty (* new_shape , device = x .device , dtype = torch .bfloat16 )
588589
589590 assert x .is_contiguous ()
590- assert x .is_cuda and output .is_cuda
591+ assert x .is_cuda or x .device .type == "xpu"
592+ assert output .is_cuda or output .device .type == "xpu"
593+
591594
592595 n_mx_blocks = x .shape [0 ]
593596 grid = lambda meta : (triton .cdiv (n_mx_blocks , meta ["BLOCK_SIZE_IN" ]),)
You can’t perform that action at this time.
0 commit comments